diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2e7c543 --- /dev/null +++ b/.gitignore @@ -0,0 +1,24 @@ +*.pyc +*.npy +*.pth +*.whl +*.swp +*.sif + +wandb/ +data/ +ckpt/ +work_dirs*/ +dist_test/ +vis/ +val/ +lib/ +logs/ + +*.egg-info +build/ +__pycache__/ +*.so + +job_scripts/ +temp_ops/ \ No newline at end of file diff --git a/LICENSE b/LICENSE index 261eeb9..5c09177 100644 --- a/LICENSE +++ b/LICENSE @@ -1,201 +1,21 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. 
- - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. 
- - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. 
- - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. +MIT License + +Copyright (c) 2024 swc-17 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md deleted file mode 100644 index b1f1ebe..0000000 --- a/README.md +++ /dev/null @@ -1,8 +0,0 @@ -
-

ForeSight

-

[ICCV2025] ForeSight: Multi-View Streaming Joint Object Detection and Trajectory Forecasting

-
-[![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2508.07089) - -Code coming soon. diff --git a/docker/Dockerfile b/docker/Dockerfile new file mode 100644 index 0000000..1f3eb07 --- /dev/null +++ b/docker/Dockerfile @@ -0,0 +1,22 @@ +FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-devel +ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" \ + TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \ + CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \ + FORCE_CUDA="1" + +# Install the MMDetection3D required packages +RUN apt-get update \ + && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# Install StreamPETR required packages +WORKDIR /workspace/ForeSight/ +COPY requirement.txt . +COPY projects/ projects/ +RUN pip install mmcv-full==1.7.1 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.13.0/index.html +RUN pip install --no-cache-dir -r requirement.txt +RUN cd projects/mmdet3d_plugin/ops && python setup.py develop +RUN git config --global --add safe.directory '*' + +WORKDIR /workspace/ForeSight/ \ No newline at end of file diff --git a/docs/quick_start.md b/docs/quick_start.md new file mode 100644 index 0000000..0ad0e8f --- /dev/null +++ b/docs/quick_start.md @@ -0,0 +1,64 @@ +# Quick Start + +### Set up a new virtual environment +```bash +conda create -n sparsedrive python=3.8 -y +conda activate sparsedrive +``` + +### Install dependency packages +```bash +sparsedrive_path="path/to/sparsedrive" +cd ${sparsedrive_path} +pip3 install --upgrade pip +pip3 install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116 +pip3 install -r requirement.txt +``` + +### Compile the deformable_aggregation CUDA op +```bash +cd projects/mmdet3d_plugin/ops +python3 setup.py develop +cd ../../../ +``` + +### Prepare the data +Download the [NuScenes 
dataset](https://www.nuscenes.org/nuscenes#download) and CAN bus expansion, put CAN bus expansion in /path/to/nuscenes, create symbolic links. +```bash +cd ${sparsedrive_path} +mkdir data +ln -s path/to/nuscenes ./data/nuscenes +``` + +Pack the meta-information and labels of the dataset, and generate the required pkl files to data/infos. Note that we also generate map_annos in data_converter, with a roi_size of (30, 60) as default, if you want a different range, you can modify roi_size in tools/data_converter/nuscenes_converter.py. +```bash +sh scripts/create_data.sh +``` + +### Generate anchors by K-means +Generated anchors are saved to data/kmeans and can be visualized in vis/kmeans. +```bash +sh scripts/kmeans.sh +``` + + +### Download pre-trained weights +Download the required backbone [pre-trained weights](https://download.pytorch.org/models/resnet50-19c8e357.pth). +```bash +mkdir ckpt +wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O ckpt/resnet50-19c8e357.pth +``` + +### Commence training and testing +```bash +# train +sh scripts/train.sh + +# test +sh scripts/test.sh +``` + +### Visualization +``` +sh scripts/visualize.sh +``` diff --git a/projects/configs/sparsedrive_r101_stage1_4gpu.py b/projects/configs/sparsedrive_r101_stage1_4gpu.py new file mode 100644 index 0000000..b6d5212 --- /dev/null +++ b/projects/configs/sparsedrive_r101_stage1_4gpu.py @@ -0,0 +1,726 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 32 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 80 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + 
interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r101_stage1_4gpu',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (1408, 512) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [8, 16, 32, 64] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=False, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=101, + num_stages=4, + frozen_stages=-1, + norm_eval=True, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=False), + init_cfg=dict( + type='Pretrained', + checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth', + prefix='backbone.')), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32] + 
out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + 
act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + 
type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + 
reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + 
pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + 
type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + 
ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.80, 0.94), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=False, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r101_stage1_8gpu.py b/projects/configs/sparsedrive_r101_stage1_8gpu.py new file mode 
100644 index 0000000..3985938 --- /dev/null +++ b/projects/configs/sparsedrive_r101_stage1_8gpu.py @@ -0,0 +1,726 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 32 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 80 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r101_stage1_8gpu',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (1408, 512) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [8, 16, 32, 64] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config 
= dict( + with_det=True, + with_map=True, + with_motion_plan=False, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=101, + num_stages=4, + frozen_stages=-1, + norm_eval=True, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=False), + init_cfg=dict( + type='Pretrained', + checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth', + prefix='backbone.')), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32] + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], 
+ temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + 
loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + 
use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + 
temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else 
"data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + 
simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.80, 0.94), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + 
warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=False, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r101_stage1_8gpu_noflash.py b/projects/configs/sparsedrive_r101_stage1_8gpu_noflash.py new file mode 100644 index 0000000..d17cd5f --- /dev/null +++ b/projects/configs/sparsedrive_r101_stage1_8gpu_noflash.py @@ -0,0 +1,726 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 32 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 80 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r101_stage1_8gpu_noflash',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (1408, 512) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 
'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [8, 16, 32, 64] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=False, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=101, + num_stages=4, + frozen_stages=-1, + norm_eval=True, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=False), + init_cfg=dict( + type='Pretrained', + checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth', + prefix='backbone.')), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32] + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + 
anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + 
num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else 
None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + 
motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + 
ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", 
"timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.80, 0.94), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + 
"nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=False, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r101_stage1_8gpu_noflash_dn.py b/projects/configs/sparsedrive_r101_stage1_8gpu_noflash_dn.py new file mode 100644 index 0000000..f84de5f --- /dev/null +++ b/projects/configs/sparsedrive_r101_stage1_8gpu_noflash_dn.py @@ -0,0 +1,726 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 32 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 80 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * 
checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r101_stage1_8gpu_noflash_dn',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (1408, 512) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [8, 16, 32, 64] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=False, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=101, + num_stages=4, + frozen_stages=-1, + norm_eval=True, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=False), + init_cfg=dict( + type='Pretrained', + checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth', + prefix='backbone.')), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=1, # skip stride-4 so depth 
branch and gt_depth align at [8, 16, 32] + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + 
ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=5, + num_temp_dn_groups=3, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + 
anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + 
reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + 
pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + 
type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + 
ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.80, 0.94), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=False, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r101_stage2_4gpu.py b/projects/configs/sparsedrive_r101_stage2_4gpu.py new file mode 
100644 index 0000000..54c06d4 --- /dev/null +++ b/projects/configs/sparsedrive_r101_stage2_4gpu.py @@ -0,0 +1,728 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 32 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r101_stage2_4gpu',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (1408, 512) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [8, 16, 32, 64] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config 
= dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=101, + num_stages=4, + frozen_stages=-1, + norm_eval=True, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=False), + init_cfg=dict( + type='Pretrained', + checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth', + prefix='backbone.')), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32] + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], 
+ temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + 
loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + 
use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + 
temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else 
"data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + 
simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.80, 0.94), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + 
warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_r101_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r101_stage2_4gpu_nomap.py b/projects/configs/sparsedrive_r101_stage2_4gpu_nomap.py new file mode 100644 index 0000000..73d381f --- /dev/null +++ b/projects/configs/sparsedrive_r101_stage2_4gpu_nomap.py @@ -0,0 +1,726 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 32 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r101_stage2_4gpu_nomap',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (1408, 512) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", 
+ "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [8, 16, 32, 64] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=101, + num_stages=4, + frozen_stages=-1, + norm_eval=True, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=False), + init_cfg=dict( + type='Pretrained', + checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth', + prefix='backbone.')), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32] + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, 
+ embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + 
refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + 
num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + 
gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + 
planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 
'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.80, 0.94), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + 
keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_r101_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r101_stage2_8gpu.py b/projects/configs/sparsedrive_r101_stage2_8gpu.py new file mode 100644 index 0000000..0fb3daa --- /dev/null +++ b/projects/configs/sparsedrive_r101_stage2_8gpu.py @@ -0,0 +1,728 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 32 +num_gpus = 8 +batch_size = total_batch_size // num_gpus 
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r101_stage2_8gpu',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (1408, 512) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [8, 16, 32, 64] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=101, + num_stages=4, + frozen_stages=-1, + norm_eval=True, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=False), + init_cfg=dict( + type='Pretrained', + 
checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth', + prefix='backbone.')), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32] + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + 
norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + 
embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + 
num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + 
embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + 
downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + 
type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.80, 0.94), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + 
interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_r101_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r101_stage2_8gpu_noflash.py b/projects/configs/sparsedrive_r101_stage2_8gpu_noflash.py new file mode 100644 index 0000000..1a1a203 --- /dev/null +++ b/projects/configs/sparsedrive_r101_stage2_8gpu_noflash.py @@ -0,0 +1,728 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 32 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r101_stage2_8gpu_noflash',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (1408, 512) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 
+num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [8, 16, 32, 64] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=101, + num_stages=4, + frozen_stages=-1, + norm_eval=True, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=False), + init_cfg=dict( + type='Pretrained', + checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth', + prefix='backbone.')), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32] + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + 
output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + 
[0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, 
+ pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + 
instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + 
type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 
'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.80, 0.94), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# 
================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_r101_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r101_stage2_8gpu_noflash_dn.py b/projects/configs/sparsedrive_r101_stage2_8gpu_noflash_dn.py new file mode 100644 index 0000000..34da3ea --- /dev/null +++ b/projects/configs/sparsedrive_r101_stage2_8gpu_noflash_dn.py @@ -0,0 +1,728 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 32 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + 
name='sparsedrive_r101_stage2_8gpu_noflash_dn',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (1408, 512) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [8, 16, 32, 64] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=101, + num_stages=4, + frozen_stages=-1, + norm_eval=True, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=False), + init_cfg=dict( + type='Pretrained', + checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth', + prefix='backbone.')), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32] + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary 
supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + 
num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=5, + num_temp_dn_groups=3, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + 
"deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + 
gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + 
type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + 
normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.80, 0.94), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + 
"rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_r101_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage1_4gpu.py b/projects/configs/sparsedrive_r50_stage1_4gpu.py new file mode 100644 index 0000000..aee18a1 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage1_4gpu.py @@ 
-0,0 +1,724 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 64 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 100 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage1_4gpu',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=False, +) + +model = dict( + 
type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + 
type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + 
map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + 
ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if 
not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + 
coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), 
+] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=4e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.5), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, 
+ with_motion=False, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu.py b/projects/configs/sparsedrive_r50_stage1_8gpu.py new file mode 100644 index 0000000..fdc93cb --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage1_8gpu.py @@ -0,0 +1,724 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 64 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 100 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage1_8gpu',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 
+num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=False, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + 
"ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + 
type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + 
deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], 
input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# 
================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", 
+ class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=4e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + 
"img_backbone": dict(lr_mult=0.5), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=False, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash.py new file mode 100644 index 0000000..475b82a --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash.py @@ -0,0 +1,724 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 64 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 100 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage1_8gpu_noflash',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + 
"construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=False, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + 
anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + 
num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else 
None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + 
motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + 
ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", 
"timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + 
"nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=4e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.5), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=False, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_dn.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_dn.py new file mode 100644 index 0000000..4bf69d4 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_dn.py @@ -0,0 +1,724 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 64 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 100 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * 
checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage1_8gpu_noflash_dn',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=False, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], 
+ ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + 
num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=5, + num_temp_dn_groups=3, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + 
operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + 
type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( 
+ type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + 
normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + 
"rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=4e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.5), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=False, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_dn_rotaug.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_dn_rotaug.py new file mode 100644 index 0000000..e8859c4 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_dn_rotaug.py @@ -0,0 +1,724 @@ +# ================ base config 
=================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 64 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 100 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage1_8gpu_noflash_dn_rotaug',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=False, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, 
+ use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not 
decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=5, + num_temp_dn_groups=3, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + 
decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + 
type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + 
dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + 
dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + 
use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [-0.3925, 0.3925], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=4e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.5), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=False, + with_planning=False, + tracking_threshold=0.2, + 
motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_gtdetmap.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_gtdetmap.py new file mode 100644 index 0000000..1324b8f --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_gtdetmap.py @@ -0,0 +1,742 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 64 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 100 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage1_8gpu_noflash_gtdetmap',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 
+num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=False, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="GTSparseDriveHead", + task_config=task_config, + num_classes=num_classes, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + 
"norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + 
loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + 
deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], 
input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# 
================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=False, + ), + 
dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + 
eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=4e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.5), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=False, + with_tracking=False, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_gtdetmap_deform.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_gtdetmap_deform.py new file mode 100644 index 0000000..3ddf06b --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_gtdetmap_deform.py @@ -0,0 +1,768 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 64 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 100 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + 
hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage1_8gpu_noflash_gtdetmap_deform',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=False, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="GTSparseDriveHead", + task_config=task_config, + num_classes=num_classes, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + 
num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", 
+ "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + 
use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="add", + kps_generator=dict( + 
type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + 
file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=False, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + 
type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=4e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.5), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = 
dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=False, + with_tracking=False, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap.py new file mode 100644 index 0000000..aac93f2 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap.py @@ -0,0 +1,722 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 64 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 100 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage1_8gpu_noflash_nomap',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", 
+ "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=False, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + 
feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + 
type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else 
embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + 
ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + 
plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), 
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + 
**data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=4e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.5), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=False, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_dn.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_dn.py new file mode 100644 index 0000000..9896ead --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_dn.py @@ -0,0 +1,722 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 64 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 100 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + 
dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage1_8gpu_noflash_nomap_dn',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=False, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + 
num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + 
use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=5, + num_temp_dn_groups=3, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ 
+ "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + 
loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, 
+ ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + 
"projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + 
train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=4e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.5), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=False, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_dn_rotaug.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_dn_rotaug.py new file mode 100644 index 0000000..cacdb33 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_dn_rotaug.py @@ -0,0 +1,722 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = 
"projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 64 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 100 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage1_8gpu_noflash_nomap_dn_rotaug',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=False, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + 
frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + 
norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=5, + num_temp_dn_groups=3, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + 
embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + 
num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + 
num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + 
downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + 
type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [-0.3925, 0.3925], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=4e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.5), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=False, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + 
interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_rotaug.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_rotaug.py new file mode 100644 index 0000000..25afc6f --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_rotaug.py @@ -0,0 +1,722 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 64 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 100 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage1_8gpu_noflash_nomap_rotaug',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=False, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * 
num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + 
gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + 
type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + 
operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = 
"NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + 
dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [-0.3925, 0.3925], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=4e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.5), + } + ), +) 
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=False, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_1gpu.py b/projects/configs/sparsedrive_r50_stage2_1gpu.py new file mode 100644 index 0000000..fdaf380 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_1gpu.py @@ -0,0 +1,726 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 8 +num_gpus = 1 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_1gpu',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + 
"bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + 
confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), 
+ sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if 
not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + 
ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, 
+ gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", 
to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=4, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, 
+ eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu.py b/projects/configs/sparsedrive_r50_stage2_4gpu.py new file mode 100644 index 0000000..0259073 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu.py @@ -0,0 +1,726 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) 
+log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary 
supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + 
num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + 
"gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + 
use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + 
type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + 
normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + 
"rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24.py new file mode 100644 index 0000000..7feb7bd --- /dev/null +++ 
b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24.py @@ -0,0 +1,726 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, 
+ with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + 
else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + 
reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + 
fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + 
type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + 
dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 
'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== 
eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_gtdetmap.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_gtdetmap.py new file mode 100644 index 0000000..2c6fe7b --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_gtdetmap.py @@ -0,0 +1,745 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_gtdetmap',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) 
+num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=False, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="GTSparseDriveHead", + task_config=task_config, + num_classes=num_classes, + num_map_classes=num_map_classes, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + 
type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + 
num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, 
+ batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + 
motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + 
plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), 
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=False, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + 
pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=False, + with_tracking=False, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap.py new file mode 100644 index 0000000..4aed9f5 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap.py @@ -0,0 +1,724 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = 
dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_nomap',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + 
out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + 
type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + 
anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + 
type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + 
dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), 
+ dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + 
map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== 
pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead.py new file mode 100644 index 0000000..378c29c --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead.py @@ -0,0 +1,743 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = 
True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"), + warmup_ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + 
embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + warmup_refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_cls_branch=True, + with_quality_estimation=with_quality_estimation, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, 
+ with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + 
type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + 
type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + 
plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] 
+test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + 
pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead_occfhtraineval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead_occfhtraineval.py new file mode 100644 index 0000000..17466ac --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead_occfhtraineval.py @@ -0,0 +1,746 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = 
int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead_occfheval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + 
type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"), + warmup_ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + warmup_refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_cls_branch=True, + with_quality_estimation=with_quality_estimation, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not 
decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), 
+ cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + 
type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, 
+ batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], 
to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 
'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train_occ.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_occ.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_occ.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( 
+ type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfheval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfheval.py new file mode 100644 index 0000000..b70c948 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfheval.py @@ -0,0 +1,728 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_occfheval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + 
"traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + 
feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + 
type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not 
decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + 
ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, 
+ gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", 
to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val_cvocc.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train_cvocc.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_cvocc.pkl", + pipeline=test_pipeline, + 
data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_cvocc.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfhtraineval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfhtraineval.py new file mode 100644 index 0000000..d0b45b6 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfhtraineval.py @@ -0,0 +1,729 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) 
+num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_occfhtraineval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + 
out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + 
act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + 
type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + 
reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + 
pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + 
type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + 
**data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val_cvocc.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train_cvocc.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_cvocc.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_cvocc.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== 
+load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfleval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfleval.py new file mode 100644 index 0000000..7a7b838 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfleval.py @@ -0,0 +1,728 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_occfleval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed 
+strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + 
"ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + 
type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + 
num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", 
+ "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = 
"data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + 
type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val_occ.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train_occ.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_occ.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_occ.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = 
dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfltraineval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfltraineval.py new file mode 100644 index 0000000..24c3b45 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfltraineval.py @@ -0,0 +1,729 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_occfltraineval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = 
(704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + 
embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + 
refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + 
num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + 
gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', 
loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 
'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val_occ.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train_occ.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + 
with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_occ.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_occ.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occpeval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occpeval.py new file mode 100644 index 0000000..0ed163a --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occpeval.py @@ -0,0 +1,728 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = 
dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_occpeval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, 
+ out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( 
+ type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + 
anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + 
type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + 
batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + 
dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + 
data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + 
interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occptraineval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occptraineval.py new file mode 100644 index 0000000..8674701 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occptraineval.py @@ -0,0 +1,729 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_occptraineval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 
+num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( 
+ [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, 
+ 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + 
act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + 
feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + 
num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = 
[ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + 
type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly.py new file mode 100644 index 0000000..cd91b47 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly.py @@ -0,0 +1,732 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_predonly',), + 
interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=False, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="GTSparseDriveHead", + task_config=task_config, + num_classes=num_classes, + 
det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + 
num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - 
num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + 
num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + 
motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + 
"gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + "gt_bboxes_3d", + "gt_labels_3d", + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 
900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=False, + with_tracking=False, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ca.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ca.py new file mode 100644 index 0000000..88ceb84 --- /dev/null +++ 
b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ca.py @@ -0,0 +1,549 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_predonly_ca',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + +# DDP: backbone/neck receive gradients but CV motion head has no trainable +# parameters, so mark unused parameters to avoid DDP sync errors. 
+find_unused_parameters = True + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=False, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="GTSparseDriveHead", + task_config=task_config, + num_classes=num_classes, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + 
instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, 
-0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + motion_plan_head=dict( + type='KinematicMotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_acceleration=True, + use_turn_rate=False, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', 
loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + 
dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + "gt_bboxes_3d", + "gt_labels_3d", + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + 
**data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=False, + with_tracking=False, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ctra.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ctra.py new file mode 100644 index 0000000..d74cb63 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ctra.py @@ -0,0 +1,549 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = 
int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_predonly_ctra',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + +# DDP: backbone/neck receive gradients but CV motion head has no trainable +# parameters, so mark unused parameters to avoid DDP sync errors. +find_unused_parameters = True + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=False, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + 
style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="GTSparseDriveHead", + task_config=task_config, + num_classes=num_classes, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + 
norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + motion_plan_head=dict( + type='KinematicMotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + 
use_acceleration=True, + use_turn_rate=True, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", 
classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + "gt_bboxes_3d", + "gt_labels_3d", + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, 
+ version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=False, + with_tracking=False, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 
'ckpt/sparsedrive_stage1.pth' diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ctrv.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ctrv.py new file mode 100644 index 0000000..5a0e4d8 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ctrv.py @@ -0,0 +1,549 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_predonly_ctrv',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + +# DDP: backbone/neck receive gradients but CV motion head has no trainable +# parameters, so mark unused parameters to avoid DDP sync errors. 
+find_unused_parameters = True + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=False, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="GTSparseDriveHead", + task_config=task_config, + num_classes=num_classes, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + 
instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, 
-0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + motion_plan_head=dict( + type='KinematicMotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_acceleration=False, + use_turn_rate=True, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', 
loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + 
dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + "gt_bboxes_3d", + "gt_labels_3d", + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + 
**data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=False, + with_tracking=False, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_cv.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_cv.py new file mode 100644 index 0000000..cc52505 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_cv.py @@ -0,0 +1,549 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] 
// (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_predonly_cv',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + +# DDP: backbone/neck receive gradients but CV motion head has no trainable +# parameters, so mark unused parameters to avoid DDP sync errors. +find_unused_parameters = True + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=False, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + 
out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="GTSparseDriveHead", + task_config=task_config, + num_classes=num_classes, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", 
normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + motion_plan_head=dict( + type='KinematicMotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_acceleration=False, + use_turn_rate=False, + 
instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + 
roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + "gt_bboxes_3d", + "gt_labels_3d", + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + 
**data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=False, + with_tracking=False, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' diff --git 
a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_deform.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_deform.py new file mode 100644 index 0000000..82a2691 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_deform.py @@ -0,0 +1,762 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_predonly_deform',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) 
+num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=False, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="GTSparseDriveHead", + task_config=task_config, + num_classes=num_classes, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + 
"refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + 
loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + 
num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + # Deformable cross-attention to sensor (image) features inserted + # after the GNN self-attention block in each of 
the 3 decoder + # iterations. residual_mode="add" keeps embed_dims unchanged so + # the existing AsymmetricFFN (in_channels=embed_dims) is compatible. + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="add", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + 
motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 
'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + "gt_bboxes_3d", + "gt_labels_3d", + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + 
workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=False, + with_tracking=False, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand.py new file mode 100644 index 0000000..a14fe94 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand.py @@ -0,0 +1,731 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' 
+length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=False, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + 
img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="GTSparseDriveHead", + task_config=task_config, + num_classes=num_classes, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not 
decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + 
decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + 
refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + 
dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + 
dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + "gt_bboxes_3d", + "gt_labels_3d", + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 
'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# 
================== eval ======================== +eval_mode = dict( + with_det=False, + with_tracking=False, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand_deform.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand_deform.py new file mode 100644 index 0000000..91c1a9a --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand_deform.py @@ -0,0 +1,761 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand_deform',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = 
len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=False, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="GTSparseDriveHead", + task_config=task_config, + num_classes=num_classes, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + 
vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + 
dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + 
dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + 
motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + # Deformable cross-attention to sensor (image) features inserted + # after the GNN self-attention block in each of the 3 decoder + # iterations. residual_mode="add" keeps embed_dims unchanged so + # the existing AsymmetricFFN (in_channels=embed_dims) is compatible. + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="add", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + 
embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] 
* len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + "gt_bboxes_3d", + "gt_labels_3d", + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + 
classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=False, + with_tracking=False, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) 
+# ================== pretrained model ======================== diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_refine.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_refine.py new file mode 100644 index 0000000..52d18e0 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_refine.py @@ -0,0 +1,730 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_predonly_refine',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs 
to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=False, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="GTSparseDriveHead", + task_config=task_config, + num_classes=num_classes, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + 
"temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + 
loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + 
embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + 
"gnn", + "norm", + "ffn", + "norm", + "refine", + ] * 3 + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" 
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + "gt_bboxes_3d", + "gt_labels_3d", + ], + meta_keys=["T_global", "T_global_inv", "timestamp", 
"instance_id"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + 
lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=False, + with_tracking=False, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_pt2.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_pt2.py new file mode 100644 index 0000000..0a2f882 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_pt2.py @@ -0,0 +1,726 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_pt2',), + interval=50) + ], +) +load_from = None +resume_from = None 
+workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + 
decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 
0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + 
type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + 
decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + 
type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + 
"gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + 
"nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1_dn.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_pt2_sephead.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_pt2_sephead.py new file mode 100644 index 0000000..71f8dec --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_pt2_sephead.py @@ -0,0 +1,745 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True 
+plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_pt2_sephead',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + 
frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"), + warmup_ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + warmup_refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_cls_branch=True, + with_quality_estimation=with_quality_estimation, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + 
"norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + 
loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + 
num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + 
"cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = 
"data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", 
classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = 
dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1_dn.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sephead_occfltrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sephead_occfltrainval.py new file mode 100644 index 0000000..ea91cbb --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sephead_occfltrainval.py @@ -0,0 +1,748 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_sephead_occptrainval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) 
+input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + 
num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"), + warmup_ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + warmup_refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_cls_branch=True, + with_quality_estimation=False, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + 
act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + 
type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + 
reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + 
pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + 
type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + 
**data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val_occ.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train_occ.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_occ.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_occ.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== 
+load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sephead_occptrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sephead_occptrainval.py new file mode 100644 index 0000000..96806a9 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sephead_occptrainval.py @@ -0,0 +1,748 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_sephead_occptrainval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # 
mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"), + warmup_ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + 
embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + warmup_refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_cls_branch=True, + with_quality_estimation=with_quality_estimation, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, 
+ with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + 
type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + 
type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, 
+ ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), 
+] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + 
ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso.py new file mode 100644 index 0000000..aac9b01 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso.py @@ -0,0 +1,757 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus 
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + 
img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"), + warmup_supervise_all=True, + warmup_ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + warmup_refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_cls_branch=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=False, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + 
temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=False, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", 
loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + loss_visibility=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=0.0, + alpha=0.2, + loss_weight=1.0, + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + 
type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + 
operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data 
======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + 'gt_visibility', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + 
class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + 
custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso_occptrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso_occptrainval.py new file mode 100644 index 0000000..f3f6aa6 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso_occptrainval.py @@ -0,0 +1,759 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso_occptrainval',), + interval=50) + ], +) 
+load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + 
cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"), + warmup_supervise_all=True, + warmup_ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + warmup_refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_cls_branch=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=False, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + 
type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=False, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + loss_visibility=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=0.0, + alpha=0.2, + loss_weight=1.0, + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + 
decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + 
refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + 
num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, 
+ file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + 'gt_visibility', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] 
+ +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + 
with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvis_occptrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvis_occptrainval.py new file mode 100644 index 0000000..f15be14 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvis_occptrainval.py @@ -0,0 +1,758 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_sepheadvis_occptrainval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = 
len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 
64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"), + warmup_ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + warmup_refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_cls_branch=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=True, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + 
residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=True, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + loss_visibility=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=0.0, + alpha=0.2, + loss_weight=1.0, + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + 
operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + 
loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", 
inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + 
roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + 'gt_visibility', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": 
(0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occfltrainval.py 
b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occfltrainval.py new file mode 100644 index 0000000..3a00736 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occfltrainval.py @@ -0,0 +1,759 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occfltrainval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out 
= 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"), + warmup_supervise_all=True, + warmup_ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", 
inplace=True), + ), + warmup_refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_cls_branch=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=True, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=True, + 
), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + loss_visibility=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=0.0, + alpha=0.4, + loss_weight=1.0, + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + 
) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + 
task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + 
ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + 'gt_visibility', + ], + 
meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val_occ.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train_occ.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + 
keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_occ.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_occ.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occptrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occptrainval.py new file mode 100644 index 0000000..23c85e4 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occptrainval.py @@ -0,0 +1,759 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = 
dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occptrainval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + 
style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"), + warmup_supervise_all=True, + warmup_ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + warmup_refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_cls_branch=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=True, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ 
+ "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=True, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + 
use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + loss_visibility=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=0.0, + alpha=0.2, + loss_weight=1.0, + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims 
* 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + 
queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + 
ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + 'gt_visibility', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + 
meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# 
================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_validflag.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_validflag.py new file mode 100644 index 0000000..521d456 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_validflag.py @@ -0,0 +1,729 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + 
project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_validflag',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + 
type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + 
kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + 
"deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + 
type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + 
ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + 
"img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + 
workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_valid_flag=True + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + use_valid_flag=True + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + use_valid_flag=True + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_vishead_occfhtraineval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_vishead_occfhtraineval.py new file mode 100644 index 0000000..6463489 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_vishead_occfhtraineval.py 
@@ -0,0 +1,738 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_vishead_occfhtraineval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + 
+model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + 
graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=True, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + loss_visibility=dict( + 
type="FocalLoss", + use_sigmoid=True, + gamma=0.0, + alpha=0.4, + loss_weight=1.0, + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + 
kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not 
decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + 
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + 'gt_visibility', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 
'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val_cvocc.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train_cvocc.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_cvocc.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_cvocc.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + 
warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_vishead_occptraineval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_vishead_occptraineval.py new file mode 100644 index 0000000..6afde0f --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_vishead_occptraineval.py @@ -0,0 +1,738 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 24 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_bs24_vishead_occptraineval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== 
+class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + 
anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + 
embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=True, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + loss_visibility=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=0.0, + alpha=0.2, + loss_weight=1.0, + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + 
type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + 
decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + 
type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + 
"gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + 'gt_visibility', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + 
ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=1.5e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_maplrdiv2.py b/projects/configs/sparsedrive_r50_stage2_4gpu_maplrdiv2.py new file mode 100644 index 0000000..7a5da18 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_maplrdiv2.py @@ -0,0 +1,726 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 
28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_maplrdiv2',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, 
+ num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + 
dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + 
num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + 
num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=5.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + 
type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + 
type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + 
+data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + 
interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_maplrdiv4.py b/projects/configs/sparsedrive_r50_stage2_4gpu_maplrdiv4.py new file mode 100644 index 0000000..4c654b5 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_maplrdiv4.py @@ -0,0 +1,726 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_maplrdiv4',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 
+num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + 
"norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + 
}, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", 
inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.25, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=2.5, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + 
feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + 
num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = 
[ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + 
paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap.py new file mode 100644 index 0000000..59e5e30 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap.py @@ -0,0 +1,724 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_nomap',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = 
dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + 
type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 
0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not 
decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + 
gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + 
motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 
'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + 
with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_noplan.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_noplan.py new file mode 100644 index 0000000..a567ccc --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_noplan.py @@ -0,0 +1,724 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + 
+total_batch_size = 48 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_nomap_noplan',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", 
requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + 
pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + 
num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + 
cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + 
ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.0, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=0.0), + plan_loss_status=dict(type='L1Loss', loss_weight=0.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + 
dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + 
version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 
'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_noplan_notempmotion.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_noplan_notempmotion.py new file mode 100644 index 0000000..79cac29 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_noplan_notempmotion.py @@ -0,0 +1,723 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_nomap_noplan_notempmotion',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 1 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # 
mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ 
+ "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + 
loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + 
embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "gnn", + 
"norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.0, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=0.0), + plan_loss_status=dict(type='L1Loss', loss_weight=0.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" 
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + 
type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = 
dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_notempmotion.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_notempmotion.py new file mode 100644 index 0000000..b10e732 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_notempmotion.py @@ -0,0 +1,723 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_nomap_notempmotion',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + 
"car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 1 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + 
anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + 
embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + 
dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + 
with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + 
ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + 
meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + 
**data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_predrefine2.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_predrefine2.py new file mode 100644 index 0000000..34d6a20 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_predrefine2.py @@ -0,0 +1,729 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 4 +batch_size = total_batch_size // num_gpus 
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_nomap_predrefine2',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + 
img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + 
feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + 
confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + 
cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] + + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + "refine", + ] * 2 + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", 
normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), 
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + 
version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 
'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_predrefine3.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_predrefine3.py new file mode 100644 index 0000000..3dd0a80 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_predrefine3.py @@ -0,0 +1,722 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_nomap_predrefine3',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed 
+strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", 
+ "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + 
type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + 
num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + "refine", + ] 
* 3 + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' 
else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + 
simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + 
warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occfltrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occfltrainval.py new file mode 100644 index 0000000..4408418 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occfltrainval.py @@ -0,0 +1,757 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occfltrainval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model 
======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + 
anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"), + warmup_supervise_all=True, + warmup_ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + warmup_refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_cls_branch=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=True, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + 
ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=True, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + loss_visibility=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=0.0, + alpha=0.4, + loss_weight=1.0, + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + 
anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + 
type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + 
dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), 
+ dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + 'gt_visibility', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + 
classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val_occ.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train_occ.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_occ.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val_occ.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + 
interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occptrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occptrainval.py new file mode 100644 index 0000000..0bb8eef --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occptrainval.py @@ -0,0 +1,757 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 4 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occptrainval',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 
+queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + 
), + temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"), + warmup_supervise_all=True, + warmup_ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + warmup_refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_cls_branch=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=True, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 
0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + with_visibility_estimation=True, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + loss_visibility=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=0.0, + alpha=0.2, + loss_weight=1.0, + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + 
"gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + 
loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + 
ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + 
"img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + 'gt_visibility', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes', + 'fut_boxes_occluded', + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + 
+data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + use_gt_mask=False, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=True, + with_occlusion=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_8gpu.py b/projects/configs/sparsedrive_r50_stage2_8gpu.py new file mode 100644 index 0000000..ba04b85 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_8gpu.py @@ -0,0 +1,726 @@ +# ================ base config 
=================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_8gpu',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + 
use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims 
if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + 
cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar 
frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + 
num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, 
+ file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, 
+ use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + 
tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_8gpu_noflash.py b/projects/configs/sparsedrive_r50_stage2_8gpu_noflash.py new file mode 100644 index 0000000..e44ef22 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_8gpu_noflash.py @@ -0,0 +1,726 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_8gpu_noflash',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + 
+embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + 
num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 
2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 
4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, 
+ tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + 
use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], 
+ ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + 
weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_8gpu_noplan.py b/projects/configs/sparsedrive_r50_stage2_8gpu_noplan.py new file mode 100644 index 0000000..956b68c --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_8gpu_noplan.py @@ -0,0 +1,726 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_8gpu_noplan',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = 
[("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + 
instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, 
-0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + 
embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 
40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + 
loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.0, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=0.0), + plan_loss_status=dict(type='L1Loss', loss_weight=0.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 
'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + 
data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_8gpu_notempmotion.py b/projects/configs/sparsedrive_r50_stage2_8gpu_notempmotion.py new file mode 100644 index 0000000..eb3cba8 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_8gpu_notempmotion.py @@ -0,0 +1,725 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" 
+work_dir = None + +total_batch_size = 48 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_8gpu_notempmotion',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 1 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + 
norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + 
in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + 
anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + 
type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + 
batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + 
dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + 
classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# 
================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_8gpu_pretrainv1_noflash_nomap.py b/projects/configs/sparsedrive_r50_stage2_8gpu_pretrainv1_noflash_nomap.py new file mode 100644 index 0000000..00917cc --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_8gpu_pretrainv1_noflash_nomap.py @@ -0,0 +1,724 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_8gpu_pretrainv1_noflash_nomap',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 
+num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + 
"ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + 
type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + 
deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], 
input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data 
======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] 
* len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } 
+ ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1_nomap_dn_rotaug.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_r50_stage2_8gpu_pretrainv2_noflash.py b/projects/configs/sparsedrive_r50_stage2_8gpu_pretrainv2_noflash.py new file mode 100644 index 0000000..d4a6929 --- /dev/null +++ b/projects/configs/sparsedrive_r50_stage2_8gpu_pretrainv2_noflash.py @@ -0,0 +1,724 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type='WandbLoggerHook', + init_kwargs=dict( + entity='trailab', + project='ForeSight', + name='sparsedrive_r50_stage2_8gpu_pretrainv2_noflash',), + interval=50) + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = 
dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=False, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + 
type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, 
-0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map 
else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + 
gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + 
planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 
'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=6, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + 
keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=False, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1_dn.pth' \ No newline at end of file diff --git a/projects/configs/sparsedrive_small_stage1.py b/projects/configs/sparsedrive_small_stage1.py new file mode 100644 index 0000000..1cbb38b --- /dev/null +++ b/projects/configs/sparsedrive_small_stage1.py @@ -0,0 +1,719 @@ +# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 64 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = 
int(length[version] // (num_gpus * batch_size)) +num_epochs = 100 +checkpoint_epoch_interval = 20 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type="TensorboardLoggerHook"), + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=False, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, 
+ in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + 
type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=0 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), 
+ num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + 
num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + 
ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", 
classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + +data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + 
"resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=batch_size, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=4e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.5), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=False, + with_planning=False, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) \ No newline at end of file diff --git a/projects/configs/sparsedrive_small_stage2.py b/projects/configs/sparsedrive_small_stage2.py new file mode 100644 index 0000000..94cd085 --- /dev/null +++ b/projects/configs/sparsedrive_small_stage2.py @@ -0,0 +1,721 @@ 
+# ================ base config =================== +version = 'mini' +version = 'trainval' +length = {'trainval': 28130, 'mini': 323} + +plugin = True +plugin_dir = "projects/mmdet3d_plugin/" +dist_params = dict(backend="nccl") +log_level = "INFO" +work_dir = None + +total_batch_size = 48 +num_gpus = 8 +batch_size = total_batch_size // num_gpus +num_iters_per_epoch = int(length[version] // (num_gpus * batch_size)) +num_epochs = 10 +checkpoint_epoch_interval = 10 + +checkpoint_config = dict( + interval=num_iters_per_epoch * checkpoint_epoch_interval +) +log_config = dict( + interval=51, + hooks=[ + dict(type="TextLoggerHook", by_epoch=False), + dict(type="TensorboardLoggerHook"), + ], +) +load_from = None +resume_from = None +workflow = [("train", 1)] +fp16 = dict(loss_scale=32.0) +input_shape = (704, 256) + + +# ================== model ======================== +class_names = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] +map_class_names = [ + 'ped_crossing', + 'divider', + 'boundary', +] +num_classes = len(class_names) +num_map_classes = len(map_class_names) +roi_size = (30, 60) + +num_sample = 20 +fut_ts = 12 +fut_mode = 6 +ego_fut_ts = 6 +ego_fut_mode = 6 +queue_length = 4 # history + current + +embed_dims = 256 +num_groups = 8 +num_decoder = 6 +num_single_frame_decoder = 1 +num_single_frame_decoder_map = 1 +use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed +strides = [4, 8, 16, 32] +num_levels = len(strides) +num_depth_layers = 3 +drop_out = 0.1 +temporal = True +temporal_map = True +decouple_attn = True +decouple_attn_map = False +decouple_attn_motion = True +with_quality_estimation = True + +task_config = dict( + with_det=True, + with_map=True, + with_motion_plan=True, +) + +model = dict( + type="SparseDrive", + use_grid_mask=True, + use_deformable_func=use_deformable_func, + img_backbone=dict( + type="ResNet", + depth=50, + 
num_stages=4, + frozen_stages=-1, + norm_eval=False, + style="pytorch", + with_cp=True, + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type="BN", requires_grad=True), + pretrained="ckpt/resnet50-19c8e357.pth", + ), + img_neck=dict( + type="FPN", + num_outs=num_levels, + start_level=0, + out_channels=embed_dims, + add_extra_convs="on_output", + relu_before_extra_convs=True, + in_channels=[256, 512, 1024, 2048], + ), + depth_branch=dict( # for auxiliary supervision only + type="DenseDepthNet", + embed_dims=embed_dims, + num_depth_layers=num_depth_layers, + loss_weight=0.2, + ), + head=dict( + type="SparseDriveHead", + task_config=task_config, + det_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn, + instance_bank=dict( + type="InstanceBank", + num_anchor=900, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_det_900.npy", + anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"), + num_temp_instances=600 if temporal else -1, + confidence_decay=0.6, + feat_grad=False, + ), + anchor_encoder=dict( + type="SparseBox3DEncoder", + vel_dims=3, + embed_dims=[128, 32, 32, 64] if decouple_attn else 256, + mode="cat" if decouple_attn else "add", + output_fc=not decouple_attn, + in_loops=1, + out_loops=4 if decouple_attn else 2, + ), + num_single_frame_decoder=num_single_frame_decoder, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder) + )[2:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + 
dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparseBox3DKeyPointsGenerator", + num_learnable_pts=6, + fix_scale=[ + [0, 0, 0], + [0.45, 0, 0], + [-0.45, 0, 0], + [0, 0.45, 0], + [0, -0.45, 0], + [0, 0, 0.45], + [0, 0, -0.45], + ], + ), + ), + refine_layer=dict( + type="SparseBox3DRefinementModule", + embed_dims=embed_dims, + num_cls=num_classes, + refine_yaw=True, + with_quality_estimation=with_quality_estimation, + ), + sampler=dict( + type="SparseBox3DTarget", + num_dn_groups=0, + num_temp_dn_groups=0, + dn_noise_scale=[2.0] * 3 + [0.5] * 7, + max_dn_gt=32, + add_neg_dn=True, + cls_weight=2.0, + box_weight=0.25, + reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4, + cls_wise_reg_weights={ + class_names.index("traffic_cone"): [ + 2.0, + 2.0, + 2.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + }, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0, + ), + loss_reg=dict( + type="SparseBox3DLoss", + loss_box=dict(type="L1Loss", loss_weight=0.25), + loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True), + loss_yawness=dict(type="GaussianFocalLoss"), + cls_allow_reverse=[class_names.index("barrier")], + ), + decoder=dict(type="SparseBox3DDecoder"), + reg_weights=[2.0] * 3 + [1.0] * 7, + ), + map_head=dict( + type="Sparse4DHead", + cls_threshold_to_reg=0.05, + decouple_attn=decouple_attn_map, + instance_bank=dict( + type="InstanceBank", + 
num_anchor=100, + embed_dims=embed_dims, + anchor="data/kmeans/kmeans_map_100.npy", + anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"), + num_temp_instances=33 if temporal_map else -1, + confidence_decay=0.6, + feat_grad=True, + ), + anchor_encoder=dict( + type="SparsePoint3DEncoder", + embed_dims=embed_dims, + num_sample=num_sample, + ), + num_single_frame_decoder=num_single_frame_decoder_map, + operation_order=( + [ + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * num_single_frame_decoder_map + + [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "ffn", + "norm", + "refine", + ] + * (num_decoder - num_single_frame_decoder_map) + )[:], + temp_graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ) + if temporal_map + else None, + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims * 2, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 4, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + deformable_model=dict( + type="DeformableFeatureAggregation", + embed_dims=embed_dims, + num_groups=num_groups, + num_levels=num_levels, + num_cams=6, + attn_drop=0.15, + use_deformable_func=use_deformable_func, + use_camera_embed=True, + residual_mode="cat", + kps_generator=dict( + type="SparsePoint3DKeyPointsGenerator", + embed_dims=embed_dims, + num_sample=num_sample, + num_learnable_pts=3, + fix_height=(0, 0.5, -0.5, 1, -1), + ground_height=-1.84023, # ground height in lidar frame + ), + ), + refine_layer=dict( + type="SparsePoint3DRefinementModule", + embed_dims=embed_dims, + 
num_sample=num_sample, + num_cls=num_map_classes, + ), + sampler=dict( + type="SparsePoint3DTarget", + assigner=dict( + type='HungarianLinesAssigner', + cost=dict( + type='MapQueriesCost', + cls_cost=dict(type='FocalLossCost', weight=1.0), + reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True), + ), + ), + num_cls=num_map_classes, + num_sample=num_sample, + roi_size=roi_size, + ), + loss_cls=dict( + type="FocalLoss", + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0, + ), + loss_reg=dict( + type="SparseLineLoss", + loss_line=dict( + type='LinesL1Loss', + loss_weight=10.0, + beta=0.01, + ), + num_sample=num_sample, + roi_size=roi_size, + ), + decoder=dict(type="SparsePoint3DDecoder"), + reg_weights=[1.0] * 40, + gt_cls_key="gt_map_labels", + gt_reg_key="gt_map_pts", + gt_id_key="map_instance_id", + with_instance_id=False, + task_prefix='map', + ), + motion_plan_head=dict( + type='MotionPlanningHead', + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy', + plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy', + embed_dims=embed_dims, + decouple_attn=decouple_attn_motion, + instance_queue=dict( + type="InstanceQueue", + embed_dims=embed_dims, + queue_length=queue_length, + tracking_threshold=0.2, + feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]), + ), + operation_order=( + [ + "temp_gnn", + "gnn", + "norm", + "cross_gnn", + "norm", + "ffn", + "norm", + ] * 3 + + [ + "refine", + ] + ), + temp_graph_model=dict( + type="MultiheadAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + graph_model=dict( + type="MultiheadFlashAttention", + embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + cross_graph_model=dict( + 
type="MultiheadFlashAttention", + embed_dims=embed_dims, + num_heads=num_groups, + batch_first=True, + dropout=drop_out, + ), + norm_layer=dict(type="LN", normalized_shape=embed_dims), + ffn=dict( + type="AsymmetricFFN", + in_channels=embed_dims, + pre_norm=dict(type="LN"), + embed_dims=embed_dims, + feedforward_channels=embed_dims * 2, + num_fcs=2, + ffn_drop=drop_out, + act_cfg=dict(type="ReLU", inplace=True), + ), + refine_layer=dict( + type="MotionPlanningRefinementModule", + embed_dims=embed_dims, + fut_ts=fut_ts, + fut_mode=fut_mode, + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + motion_sampler=dict( + type="MotionTarget", + ), + motion_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2 + ), + motion_loss_reg=dict(type='L1Loss', loss_weight=0.2), + planning_sampler=dict( + type="PlanningTarget", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + ), + plan_loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.5, + ), + plan_loss_reg=dict(type='L1Loss', loss_weight=1.0), + plan_loss_status=dict(type='L1Loss', loss_weight=1.0), + motion_decoder=dict(type="SparseBox3DMotionDecoder"), + planning_decoder=dict( + type="HierarchicalPlanningDecoder", + ego_fut_ts=ego_fut_ts, + ego_fut_mode=ego_fut_mode, + use_rescore=True, + ), + num_det=50, + num_map=10, + ), + ), +) + +# ================== data ======================== +dataset_type = "NuScenes3DDataset" +data_root = "data/nuscenes/" +anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/" +file_client_args = dict(backend="disk") + +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True +) +train_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict( + type="LoadPointsFromFile", + coord_type="LIDAR", + load_dim=5, + use_dim=5, + file_client_args=file_client_args, + ), + dict(type="ResizeCropFlipImage"), + dict( + 
type="MultiScaleDepthMapGenerator", + downsample=strides[:num_depth_layers], + ), + dict(type="BBoxRotation"), + dict(type="PhotoMetricDistortionMultiViewImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=False, + normalize=False, + sample_num=num_sample, + permute=True, + ), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + "gt_depth", + "focal", + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_map_labels', + 'gt_map_pts', + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'ego_status', + ], + meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"], + ), +] +test_pipeline = [ + dict(type="LoadMultiViewImageFromFiles", to_float32=True), + dict(type="ResizeCropFlipImage"), + dict(type="NormalizeMultiviewImage", **img_norm_cfg), + dict(type="NuScenesSparse4DAdaptor"), + dict( + type="Collect", + keys=[ + "img", + "timestamp", + "projection_mat", + "image_wh", + 'ego_status', + 'gt_ego_fut_cmd', + ], + meta_keys=["T_global", "T_global_inv", "timestamp"], + ), +] +eval_pipeline = [ + dict( + type="CircleObjectRangeFilter", + class_dist_thred=[55] * len(class_names), + ), + dict(type="InstanceNameFilter", classes=class_names), + dict( + type='VectorizeMap', + roi_size=roi_size, + simplify=True, + normalize=False, + ), + dict( + type='Collect', + keys=[ + 'vectors', + "gt_bboxes_3d", + "gt_labels_3d", + 'gt_agent_fut_trajs', + 'gt_agent_fut_masks', + 'gt_ego_fut_trajs', + 'gt_ego_fut_masks', + 'gt_ego_fut_cmd', + 'fut_boxes' + ], + meta_keys=['token', 'timestamp'] + ), +] + +input_modality = dict( + use_lidar=False, + use_camera=True, + use_radar=False, + use_map=False, + use_external=False, +) + 
+data_basic_config = dict( + type=dataset_type, + data_root=data_root, + classes=class_names, + map_classes=map_class_names, + modality=input_modality, + version="v1.0-trainval", +) +eval_config = dict( + **data_basic_config, + ann_file=anno_root + 'nuscenes_infos_val.pkl', + pipeline=eval_pipeline, + test_mode=True, +) +data_aug_conf = { + "resize_lim": (0.40, 0.47), + "final_dim": input_shape[::-1], + "bot_pct_lim": (0.0, 0.0), + "rot_lim": (-5.4, 5.4), + "H": 900, + "W": 1600, + "rand_flip": True, + "rot3d_range": [0, 0], +} + +data = dict( + samples_per_gpu=batch_size, + workers_per_gpu=batch_size, + train=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_train.pkl", + pipeline=train_pipeline, + test_mode=False, + data_aug_conf=data_aug_conf, + with_seq_flag=True, + sequences_split_num=2, + keep_consistent_seq_aug=True, + ), + val=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), + test=dict( + **data_basic_config, + ann_file=anno_root + "nuscenes_infos_val.pkl", + pipeline=test_pipeline, + data_aug_conf=data_aug_conf, + test_mode=True, + eval_config=eval_config, + ), +) + +# ================== training ======================== +optimizer = dict( + type="AdamW", + lr=3e-4, + weight_decay=0.001, + paramwise_cfg=dict( + custom_keys={ + "img_backbone": dict(lr_mult=0.1), + } + ), +) +optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2)) +lr_config = dict( + policy="CosineAnnealing", + warmup="linear", + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-3, +) +runner = dict( + type="IterBasedRunner", + max_iters=num_iters_per_epoch * num_epochs, +) + +# ================== eval ======================== +eval_mode = dict( + with_det=True, + with_tracking=True, + with_map=True, + with_motion=True, + with_planning=True, + tracking_threshold=0.2, + motion_threshhold=0.2, +) +evaluation = dict( + 
interval=num_iters_per_epoch*checkpoint_epoch_interval, + eval_mode=eval_mode, +) +# ================== pretrained model ======================== +load_from = 'ckpt/sparsedrive_stage1.pth' \ No newline at end of file diff --git a/projects/mmdet3d_plugin/__init__.py b/projects/mmdet3d_plugin/__init__.py new file mode 100644 index 0000000..e6f4ea2 --- /dev/null +++ b/projects/mmdet3d_plugin/__init__.py @@ -0,0 +1,4 @@ +from .datasets import * +from .models import * +from .apis import * +from .core.evaluation import * diff --git a/projects/mmdet3d_plugin/apis/__init__.py b/projects/mmdet3d_plugin/apis/__init__.py new file mode 100644 index 0000000..baab0c6 --- /dev/null +++ b/projects/mmdet3d_plugin/apis/__init__.py @@ -0,0 +1,4 @@ +from .train import custom_train_model +from .mmdet_train import custom_train_detector + +# from .test import custom_multi_gpu_test diff --git a/projects/mmdet3d_plugin/apis/mmdet_train.py b/projects/mmdet3d_plugin/apis/mmdet_train.py new file mode 100644 index 0000000..ad6dc60 --- /dev/null +++ b/projects/mmdet3d_plugin/apis/mmdet_train.py @@ -0,0 +1,219 @@ +# --------------------------------------------- +# Copyright (c) OpenMMLab. All rights reserved. 
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import os.path as osp
import random
import time
import warnings

import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (
    HOOKS,
    DistSamplerSeedHook,
    EpochBasedRunner,
    Fp16OptimizerHook,
    OptimizerHook,
    build_optimizer,
    build_runner,
    get_dist_info,
)
from mmcv.utils import build_from_cfg

from mmdet.core import EvalHook
from mmdet.datasets import build_dataset, replace_ImageToTensor
from mmdet.utils import get_root_logger

from projects.mmdet3d_plugin.datasets.builder import build_dataloader
from projects.mmdet3d_plugin.core.evaluation.eval_hooks import (
    CustomDistEvalHook,
)
from projects.mmdet3d_plugin.datasets import custom_build_dataset


def custom_train_detector(
    model,
    dataset,
    cfg,
    distributed=False,
    validate=False,
    timestamp=None,
    meta=None,
):
    """Train a detector using this plugin's dataloader and eval hooks.

    Mirrors ``mmdet.apis.train_detector`` but routes dataloader
    construction through ``projects.mmdet3d_plugin`` so the grouped /
    sequential samplers and the custom distributed evaluation hook are
    used.

    Args:
        model (nn.Module): Detector to train.
        dataset (Dataset | list[Dataset]): Training dataset(s); one
            dataloader is built per dataset.
        cfg (Config): Full training configuration.
        distributed (bool): Whether DDP training is in use.
        validate (bool): Whether to register an evaluation hook.
        timestamp (str | None): Launch timestamp, reused so log files
            share a name.
        meta (dict | None): Extra metadata recorded by the runner.
    """
    logger = get_root_logger(cfg.log_level)

    # Normalise to a list: one dataloader is built per dataset below.
    if not isinstance(dataset, (list, tuple)):
        dataset = [dataset]

    # Backwards compatibility with the pre-MMDet-2.0 "imgs_per_gpu" key;
    # when present it always wins over "samples_per_gpu".
    if "imgs_per_gpu" in cfg.data:
        logger.warning(
            '"imgs_per_gpu" is deprecated in MMDet V2.0. '
            'Please use "samples_per_gpu" instead'
        )
        if "samples_per_gpu" in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f"={cfg.data.imgs_per_gpu} is used in this experiments"
            )
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f"{cfg.data.imgs_per_gpu} in this experiments"
            )
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu

    runner_type = cfg.runner["type"] if "runner" in cfg else "EpochBasedRunner"
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            len(cfg.gpu_ids),  # ignored when distributed
            dist=distributed,
            seed=cfg.seed,
            nonshuffler_sampler=dict(type="DistributedSampler"),
            runner_type=runner_type,
        )
        for ds in dataset
    ]

    # Wrap the model for single- or multi-GPU execution.
    if distributed:
        # `find_unused_parameters` is forwarded to
        # torch.nn.parallel.DistributedDataParallel.
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=cfg.get("find_unused_parameters", False),
        )
    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids
        )

    optimizer = build_optimizer(model, cfg.optimizer)

    if "runner" not in cfg:
        cfg.runner = {
            "type": "EpochBasedRunner",
            "max_epochs": cfg.total_epochs,
        }
        warnings.warn(
            "config is now expected to have a `runner` section, "
            "please set `runner` in your config.",
            UserWarning,
        )
    elif "total_epochs" in cfg:
        assert cfg.total_epochs == cfg.runner.max_epochs

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta,
        ),
    )

    # Reuse the launch timestamp so .log and .log.json share a filename.
    runner.timestamp = timestamp

    # Pick the optimizer hook: fp16 wins, then plain distributed, then
    # whatever the config provided.
    fp16_cfg = cfg.get("fp16", None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed
        )
    elif distributed and "type" not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get("momentum_config", None),
    )

    if distributed and isinstance(runner, EpochBasedRunner):
        runner.register_hook(DistSamplerSeedHook())

    if validate:
        # NOTE(review): validation with batch_size > 1 is deliberately
        # unsupported here — the assert below fires if it is configured.
        val_samples_per_gpu = cfg.data.val.pop("samples_per_gpu", 1)
        if val_samples_per_gpu > 1:
            assert False
            # Replace 'ImageToTensor' with 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline
            )
        val_dataset = custom_build_dataset(cfg.data.val, dict(test_mode=True))

        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            nonshuffler_sampler=dict(type="DistributedSampler"),
        )
        eval_cfg = cfg.get("evaluation", {})
        eval_cfg["by_epoch"] = cfg.runner["type"] != "IterBasedRunner"
        eval_cfg["jsonfile_prefix"] = osp.join(
            "val",
            cfg.work_dir,
            time.ctime().replace(" ", "_").replace(":", "_"),
        )
        eval_hook = CustomDistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # Register any user-defined hooks from the config, honouring an
    # optional per-hook priority.
    if cfg.get("custom_hooks", None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(
            custom_hooks, list
        ), f"custom_hooks expect list type, but got {type(custom_hooks)}"
        for hook_cfg in custom_hooks:
            assert isinstance(hook_cfg, dict), (
                "Each item in custom_hooks expects dict type, but got "
                f"{type(hook_cfg)}"
            )
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop("priority", "NORMAL")
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------
import os.path as osp
import pickle
import shutil
import tempfile
import time

import mmcv
import numpy as np
import pycocotools.mask as mask_util
import torch
import torch.distributed as dist
from mmcv.image import tensor2imgs
from mmcv.runner import get_dist_info

from mmdet.core import encode_mask_results


def custom_encode_mask_results(mask_results):
    """Encode bitmap mask to RLE code. Semantic Masks only

    Args:
        mask_results (list | tuple[list]): bitmap mask results.
            In mask scoring rcnn, mask_results is a tuple of (segm_results,
            segm_cls_score).

    Returns:
        list | tuple: RLE encoded mask.
    """
    cls_segms = mask_results
    encoded_mask_results = []
    for i in range(len(cls_segms)):
        # pycocotools expects Fortran-ordered uint8 arrays of shape (H, W, 1).
        encoded_mask_results.append(
            mask_util.encode(
                np.array(
                    cls_segms[i][:, :, np.newaxis], order="F", dtype="uint8"
                )
            )[0]
        )  # encoded with RLE
    return [encoded_mask_results]


def custom_multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Test model with multiple gpus.

    This method tests model with multiple gpus and collects the results
    under two different modes: gpu and cpu modes. By setting
    'gpu_collect=True' it encodes results to gpu tensors and uses gpu
    communication for results collection. On cpu mode it saves the results
    on different gpus to 'tmpdir' and collects them by the rank 0 worker.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (nn.Dataloader): Pytorch data loader.
        tmpdir (str): Path of directory to save the temporary results from
            different gpus under cpu mode.
        gpu_collect (bool): Option to use either gpu or cpu to collect
            results.

    Returns:
        list | dict: bbox results, or a dict holding both bbox results and
        RLE-encoded mask results when the model also produced masks.
    """
    model.eval()
    bbox_results = []
    mask_results = []
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problem in some cases.
    have_mask = False
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)
            # encode mask results
            if isinstance(result, dict):
                if "bbox_results" in result.keys():
                    bbox_result = result["bbox_results"]
                    batch_size = len(result["bbox_results"])
                    bbox_results.extend(bbox_result)
                if (
                    "mask_results" in result.keys()
                    and result["mask_results"] is not None
                ):
                    mask_result = custom_encode_mask_results(
                        result["mask_results"]
                    )
                    mask_results.extend(mask_result)
                    have_mask = True
            else:
                batch_size = len(result)
                bbox_results.extend(result)

        if rank == 0:
            # Advance by the global batch so the bar tracks all ranks.
            for _ in range(batch_size * world_size):
                prog_bar.update()

    # collect results from all ranks
    if gpu_collect:
        bbox_results = collect_results_gpu(bbox_results, len(dataset))
        if have_mask:
            mask_results = collect_results_gpu(mask_results, len(dataset))
        else:
            mask_results = None
    else:
        bbox_results = collect_results_cpu(bbox_results, len(dataset), tmpdir)
        tmpdir = tmpdir + "_mask" if tmpdir is not None else None
        if have_mask:
            mask_results = collect_results_cpu(
                mask_results, len(dataset), tmpdir
            )
        else:
            mask_results = None

    if mask_results is None:
        return bbox_results
    return {"bbox_results": bbox_results, "mask_results": mask_results}


def collect_results_cpu(result_part, size, tmpdir=None):
    """Gather per-rank result lists through a shared temp directory.

    Each rank pickles its part into ``tmpdir``; rank 0 loads every part,
    concatenates them in rank order and trims dataloader padding.

    Args:
        result_part (list): This rank's share of the results.
        size (int): Total dataset size (used to strip padded samples).
        tmpdir (str | None): Shared directory; auto-created and broadcast
            from rank 0 when ``None``.

    Returns:
        list | None: Merged results on rank 0, ``None`` on other ranks.
    """
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full(
            (MAX_LEN,), 32, dtype=torch.uint8, device="cuda"
        )
        if rank == 0:
            mmcv.mkdir_or_exist(".dist_test")
            tmpdir = tempfile.mkdtemp(dir=".dist_test")
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device="cuda"
            )
            dir_tensor[: len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mmcv.mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f"part_{rank}.pkl"))
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            part_file = osp.join(tmpdir, f"part_{i}.pkl")
            part_list.append(mmcv.load(part_file))
        # Because the evaluation sampler hands each gpu a contiguous chunk
        # of samples, parts are concatenated in rank order here instead of
        # interleaved as in upstream mmdet (`for res in zip(*part_list)`).
        ordered_results = []
        for res in part_list:
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results


def collect_results_gpu(result_part, size):
    """Collect results when ``gpu_collect=True``.

    BUGFIX: the original called ``collect_results_cpu`` without returning
    its value, so gpu collection always produced ``None`` on every rank.
    Delegates to the tmpdir-based CPU collection.
    """
    return collect_results_cpu(result_part, size)
# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
# Modified by Zhiqi Li
# ---------------------------------------------

from .mmdet_train import custom_train_detector
# from mmseg.apis import train_segmentor
from mmdet.apis import train_detector


def custom_train_model(
    model,
    dataset,
    cfg,
    distributed=False,
    validate=False,
    timestamp=None,
    meta=None,
):
    """A function wrapper for launching model training according to cfg.

    Because we need different eval_hook in runner. Should be deprecated in
    the future.

    Raises:
        NotImplementedError: if the config describes a 3D segmentor
            (``EncoderDecoder3D``), which this plugin does not train.
            (Was a bare ``assert False``, which is silently stripped under
            ``python -O``.)
    """
    if cfg.model.type in ["EncoderDecoder3D"]:
        raise NotImplementedError(
            "3D segmentor training is not supported by this plugin"
        )
    custom_train_detector(
        model,
        dataset,
        cfg,
        distributed=distributed,
        validate=validate,
        timestamp=timestamp,
        meta=meta,
    )


def train_model(
    model,
    dataset,
    cfg,
    distributed=False,
    validate=False,
    timestamp=None,
    meta=None,
):
    """A function wrapper for launching model training according to cfg.

    Because we need different eval_hook in runner. Should be deprecated in
    the future.
    """
    train_detector(
        model,
        dataset,
        cfg,
        distributed=distributed,
        validate=validate,
        timestamp=timestamp,
        meta=meta,
    )


# projects/mmdet3d_plugin/core/box3d.py
# Index layout of the (un)decoded box tensors used across the plugin.
X, Y, Z, W, L, H, SIN_YAW, COS_YAW, VX, VY, VZ = list(range(11))  # undecoded
CNS, YNS = 0, 1  # centerness and yawness indices in quality
YAW = 6  # decoded


# projects/mmdet3d_plugin/core/evaluation/__init__.py
from .eval_hooks import CustomDistEvalHook


# Note: Considering that MMCV's EvalHook updated its interface in V1.3.16,
# in order to avoid strong version dependency, we did not directly
# inherit EvalHook but BaseDistEvalHook.
import bisect
import os.path as osp

import mmcv
import torch.distributed as dist
from mmcv.runner import DistEvalHook as BaseDistEvalHook
from mmcv.runner import EvalHook as BaseEvalHook
from torch.nn.modules.batchnorm import _BatchNorm
from mmdet.core.evaluation.eval_hooks import DistEvalHook


def _calc_dynamic_intervals(start_interval, dynamic_interval_list):
    """Split (milestone, interval) pairs into two parallel lists.

    The first milestone is 0 and the first interval is ``start_interval``
    so that ``bisect`` lookups always find a valid entry.
    """
    assert mmcv.is_list_of(dynamic_interval_list, tuple)

    dynamic_milestones = [0]
    dynamic_milestones.extend(
        milestone for milestone, _ in dynamic_interval_list
    )
    dynamic_intervals = [start_interval]
    dynamic_intervals.extend(
        interval for _, interval in dynamic_interval_list
    )
    return dynamic_milestones, dynamic_intervals


class CustomDistEvalHook(BaseDistEvalHook):
    """Distributed eval hook that runs the plugin's multi-gpu test loop.

    Optionally adjusts its evaluation interval over the course of training
    via ``dynamic_intervals`` — a list of (milestone, interval) tuples.
    """

    def __init__(self, *args, dynamic_intervals=None, **kwargs):
        super(CustomDistEvalHook, self).__init__(*args, **kwargs)
        self.use_dynamic_intervals = dynamic_intervals is not None
        if self.use_dynamic_intervals:
            (
                self.dynamic_milestones,
                self.dynamic_intervals,
            ) = _calc_dynamic_intervals(self.interval, dynamic_intervals)

    def _decide_interval(self, runner):
        # Pick the interval whose milestone has been reached.
        if self.use_dynamic_intervals:
            progress = runner.epoch if self.by_epoch else runner.iter
            step = bisect.bisect(self.dynamic_milestones, (progress + 1))
            self.interval = self.dynamic_intervals[step - 1]

    def before_train_epoch(self, runner):
        """Refresh the interval, then defer to the base hook (by epoch)."""
        self._decide_interval(runner)
        super().before_train_epoch(runner)

    def before_train_iter(self, runner):
        self._decide_interval(runner)
        super().before_train_iter(runner)

    def _do_evaluate(self, runner):
        """perform evaluation and save ckpt."""
        # DDP does not synchronise BatchNorm buffers (running_mean /
        # running_var), which can leave ranks with inconsistent models, so
        # rank 0's buffers are broadcast to all other ranks first.
        if self.broadcast_bn_buffer:
            model = runner.model
            for name, module in model.named_modules():
                if (
                    isinstance(module, _BatchNorm)
                    and module.track_running_stats
                ):
                    dist.broadcast(module.running_var, 0)
                    dist.broadcast(module.running_mean, 0)

        if not self._should_evaluate(runner):
            return

        tmpdir = self.tmpdir
        if tmpdir is None:
            tmpdir = osp.join(runner.work_dir, ".eval_hook")

        # Imported lazily to avoid a circular import with apis.test.
        from projects.mmdet3d_plugin.apis.test import (
            custom_multi_gpu_test,
        )

        results = custom_multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=tmpdir,
            gpu_collect=self.gpu_collect,
        )
        if runner.rank == 0:
            print("\n")
            runner.log_buffer.output["eval_iter_num"] = len(self.dataloader)

            key_score = self.evaluate(runner, results)

            if self.save_best:
                self._save_ckpt(runner, key_score)


# projects/mmdet3d_plugin/datasets/__init__.py
from .nuscenes_3d_dataset import NuScenes3DDataset
from .builder import *
from .pipelines import *
from .samplers import *

__all__ = [
    'NuScenes3DDataset',
    "custom_build_dataset",
]
import copy
import platform
import random
from functools import partial

import numpy as np
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
from torch.utils.data import DataLoader

from mmdet.datasets import DATASETS
from mmdet.datasets.builder import _concat_dataset
from mmdet.datasets.samplers import GroupSampler
from projects.mmdet3d_plugin.datasets.samplers import (
    GroupInBatchSampler,
    DistributedGroupSampler,
    DistributedSampler,
    build_sampler
)

if platform.system() != "Windows":
    # Raise the open-file soft limit; many dataloader workers can exhaust
    # the default. https://github.com/pytorch/pytorch/issues/973
    import resource

    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    base_soft_limit = rlimit[0]
    hard_limit = rlimit[1]
    soft_limit = min(max(4096, base_soft_limit), hard_limit)
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))

OBJECTSAMPLERS = Registry("Object sampler")


def build_dataloader(
    dataset,
    samples_per_gpu,
    workers_per_gpu,
    num_gpus=1,
    dist=True,
    shuffle=True,
    seed=None,
    shuffler_sampler=None,
    nonshuffler_sampler=None,
    runner_type="EpochBasedRunner",
    **kwargs
):
    """Build PyTorch DataLoader.

    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.

    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU,
            i.e., batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data
            loading for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed
            training.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        shuffler_sampler (dict | None): Sampler config used when
            shuffling; defaults to ``DistributedGroupSampler``.
        nonshuffler_sampler (dict | None): Sampler config used when not
            shuffling; defaults to ``DistributedSampler``.
        runner_type (str): ``IterBasedRunner`` switches to the grouped
            in-batch sampler.
        kwargs: any keyword argument to be used to initialize DataLoader.

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()
    batch_sampler = None
    if runner_type == 'IterBasedRunner':
        # Iteration-based training keeps temporally-grouped samples
        # together inside each batch; batching is done by the sampler, so
        # the DataLoader's own batch_size must be 1.
        print("Use GroupInBatchSampler !!!")
        batch_sampler = GroupInBatchSampler(
            dataset,
            samples_per_gpu,
            world_size,
            rank,
            seed=seed,
        )
        batch_size = 1
        sampler = None
        num_workers = workers_per_gpu
    elif dist:
        # DistributedGroupSampler will definitely shuffle the data to
        # satisfy that images on each GPU are in the same group.
        if shuffle:
            print("Use DistributedGroupSampler !!!")
            sampler = build_sampler(
                shuffler_sampler
                if shuffler_sampler is not None
                else dict(type="DistributedGroupSampler"),
                dict(
                    dataset=dataset,
                    samples_per_gpu=samples_per_gpu,
                    num_replicas=world_size,
                    rank=rank,
                    seed=seed,
                ),
            )
        else:
            sampler = build_sampler(
                nonshuffler_sampler
                if nonshuffler_sampler is not None
                else dict(type="DistributedSampler"),
                dict(
                    dataset=dataset,
                    num_replicas=world_size,
                    rank=rank,
                    shuffle=shuffle,
                    seed=seed,
                ),
            )

        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        # Non-distributed mode is only kept for benchmarking inference.
        print("WARNING!!!!, Only can be used for obtain inference speed!!!!")
        sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    init_fn = (
        partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
        if seed is not None
        else None
    )

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=False,
        worker_init_fn=init_fn,
        **kwargs
    )

    return data_loader


def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed numpy / random per worker so augmentation is reproducible.

    The seed of each worker equals
    num_workers * rank + worker_id + user_seed.
    """
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def custom_build_dataset(cfg, default_args=None):
    """Recursively build a (possibly wrapped) dataset from config.

    Supports the Concat / Repeat / ClassBalanced / CBGS wrappers as well
    as multi-annotation-file concatenation; falls back to the mmdet
    DATASETS registry for plain dataset configs.
    """
    try:
        from mmdet3d.datasets.dataset_wrappers import CBGSDataset
    except ImportError:  # mmdet3d is optional; only needed for CBGSDataset
        CBGSDataset = None
    from mmdet.datasets.dataset_wrappers import (
        ClassBalancedDataset,
        ConcatDataset,
        RepeatDataset,
    )

    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset(
            [custom_build_dataset(c, default_args) for c in cfg]
        )
    elif cfg["type"] == "ConcatDataset":
        dataset = ConcatDataset(
            [custom_build_dataset(c, default_args) for c in cfg["datasets"]],
            cfg.get("separate_eval", True),
        )
    elif cfg["type"] == "RepeatDataset":
        dataset = RepeatDataset(
            custom_build_dataset(cfg["dataset"], default_args), cfg["times"]
        )
    elif cfg["type"] == "ClassBalancedDataset":
        dataset = ClassBalancedDataset(
            custom_build_dataset(cfg["dataset"], default_args),
            cfg["oversample_thr"],
        )
    elif cfg["type"] == "CBGSDataset":
        dataset = CBGSDataset(
            custom_build_dataset(cfg["dataset"], default_args)
        )
    elif isinstance(cfg.get("ann_file"), (list, tuple)):
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset
import numpy as np
from collections import Counter
from typing import Callable, Dict, Optional, Tuple

from nuscenes.eval.common.data_classes import EvalBoxes
from nuscenes.eval.common.loaders import load_gt, add_center_dist
from nuscenes.eval.common.utils import (
    center_distance, scale_iou, yaw_diff, velocity_l2, attr_acc, cummean,
)
from nuscenes.eval.detection.data_classes import DetectionBox, DetectionMetricData
from nuscenes.eval.detection.evaluate import NuScenesEval


# ---------------------------------------------------------------------------
# Adaptive matching threshold coefficients (UniTraj class mapping)
# Formula: d = dist_th + alpha*t + beta*v*t + gamma*a*t^2
# ---------------------------------------------------------------------------

ADAPTIVE_COEFFS = {
    'vehicle': (0.0568, 0.1962, 0.2133),
    'cyclist': (0.1023, 0.1861, 0.2266),
    'pedestrian': (0.2641, 0.1457, 0.1774),
}

NUSCENES_TO_UNITRAJ = {
    'car': 'vehicle',
    'truck': 'vehicle',
    'construction_vehicle': 'vehicle',
    'bus': 'vehicle',
    'trailer': 'vehicle',
    'motorcycle': 'cyclist',
    'bicycle': 'cyclist',
    'pedestrian': 'pedestrian',
    # barrier, traffic_cone: no adaptive threshold
}


# ---------------------------------------------------------------------------
# Adaptive threshold helpers
# ---------------------------------------------------------------------------

def _ann_speed(nusc, ann: dict) -> float:
    """Speed (m/s) of an annotation, estimated from the position delta to
    its predecessor annotation (0.0 when there is no predecessor)."""
    if ann['prev'] == '':
        return 0.0
    prev_ann = nusc.get('sample_annotation', ann['prev'])
    ts_now = nusc.get('sample', ann['sample_token'])['timestamp']
    ts_prev = nusc.get('sample', prev_ann['sample_token'])['timestamp']
    elapsed = (ts_now - ts_prev) * 1e-6  # timestamps are in microseconds
    if elapsed <= 0:
        return 0.0
    delta_x = ann['translation'][0] - prev_ann['translation'][0]
    delta_y = ann['translation'][1] - prev_ann['translation'][1]
    return np.sqrt(delta_x ** 2 + delta_y ** 2) / elapsed


def _occ_metadata(nusc, ann_token: str) -> Tuple[float, float, float]:
    """Return (t, v, a) for an occluded annotation.

    t : occlusion duration in seconds (>= 0.5 s since the current frame is
        occluded)
    v : speed at the last visible annotation (m/s)
    a : acceleration magnitude at the last visible annotation (m/s^2)
    """
    ann = nusc.get('sample_annotation', ann_token)

    # Walk the prev-links, counting consecutive occluded frames, until the
    # last visible annotation (and the one before it) is found.
    occluded_frames = 1  # the current frame counts as one occluded frame
    last_vis = None
    prev_of_last_vis = None

    token = ann['prev']
    while token != '':
        candidate = nusc.get('sample_annotation', token)
        if candidate['num_lidar_pts'] > 0:
            last_vis = candidate
            if candidate['prev'] != '':
                prev_of_last_vis = nusc.get(
                    'sample_annotation', candidate['prev']
                )
            break
        occluded_frames += 1
        token = candidate['prev']

    t = occluded_frames * 0.5  # NuScenes keyframes are 2 Hz -> 0.5 s/frame

    v = _ann_speed(nusc, last_vis) if last_vis is not None else 0.0

    a = 0.0
    if last_vis is not None and prev_of_last_vis is not None:
        v_last = _ann_speed(nusc, last_vis)
        v_prev = _ann_speed(nusc, prev_of_last_vis)
        ts_last = nusc.get('sample', last_vis['sample_token'])['timestamp']
        ts_prev = nusc.get(
            'sample', prev_of_last_vis['sample_token']
        )['timestamp']
        elapsed = (ts_last - ts_prev) * 1e-6
        if elapsed > 0:
            a = abs(v_last - v_prev) / elapsed

    return t, v, a


def _adaptive_dist_th(
    base_dist_th: float,
    class_name: str,
    t: float,
    v: float,
    a: float,
) -> float:
    """d = dist_th + alpha*t + beta*v*t + gamma*a*t^2"""
    unitraj_cls = NUSCENES_TO_UNITRAJ.get(class_name)
    if unitraj_cls is None:
        # Static classes (barrier, traffic_cone) keep the fixed threshold.
        return base_dist_th
    alpha, beta, gamma = ADAPTIVE_COEFFS[unitraj_cls]
    return base_dist_th + alpha * t + beta * v * t + gamma * a * t ** 2
# ---------------------------------------------------------------------------
# Custom accumulate with three-outcome matching
# ---------------------------------------------------------------------------

def accumulate_with_ignore(
    gt_boxes: EvalBoxes,
    pred_boxes: EvalBoxes,
    ignore_boxes: EvalBoxes,
    class_name: str,
    dist_fcn: Callable,
    dist_th: float,
    per_gt_dist_ths: Optional[Dict[Tuple[str, int], float]] = None,
    verbose: bool = False,
) -> DetectionMetricData:
    """AP accumulation with a three-outcome matching rule.

    For each prediction (processed in descending confidence order):

    1. **TP** - prediction matches an unmatched visible GT box within its
       effective distance threshold (adaptive if per_gt_dist_ths provided,
       otherwise the fixed dist_th).
    2. **Ignored** - prediction does not match visible GT, but matches an
       ignore box within dist_th. The prediction is excluded from both the
       numerator and denominator of the precision-recall curve.
    3. **FP** - prediction matches neither.

    Parameters
    ----------
    gt_boxes: Visible GT boxes used for scoring (TP/FP/recall denominator).
    pred_boxes: All model predictions.
    ignore_boxes: GT boxes that neutralise unmatched preds.
    class_name: Detection class to evaluate.
    dist_fcn: BEV distance function.
    dist_th: Base match / ignore distance threshold in metres.
    per_gt_dist_ths: Optional dict mapping (sample_token, gt_idx) ->
        adaptive threshold for TP matching. Ignore matching always uses
        the fixed dist_th.
    """
    npos = sum(1 for b in gt_boxes.all if b.detection_name == class_name)
    if verbose:
        print(f'Found {npos} GT of class {class_name} across '
              f'{len(gt_boxes.sample_tokens)} samples.')

    if npos == 0:
        return DetectionMetricData.no_predictions()

    # Pre-bucket the ignore boxes of this class per sample token.
    ignore_by_token = {
        tok: [b for b in ignore_boxes[tok] if b.detection_name == class_name]
        for tok in ignore_boxes.sample_tokens
    }

    cls_preds = [b for b in pred_boxes.all if b.detection_name == class_name]
    scores = [b.detection_score for b in cls_preds]
    # Prediction indices ordered by descending confidence.
    order = [
        i for (s, i) in sorted((s, i) for (i, s) in enumerate(scores))
    ][::-1]

    tp, fp, conf = [], [], []
    match_data = {
        'trans_err': [], 'vel_err': [], 'scale_err': [],
        'orient_err': [], 'attr_err': [], 'conf': [],
    }

    taken = set()

    for idx in order:
        pred_box = cls_preds[idx]

        # --- Step 1: find the nearest still-unmatched GT of this class ---
        best_dist = np.inf
        best_gt = None
        for gt_idx, gt_box in enumerate(gt_boxes[pred_box.sample_token]):
            if gt_box.detection_name != class_name:
                continue
            if (pred_box.sample_token, gt_idx) in taken:
                continue
            dist = dist_fcn(gt_box, pred_box)
            if dist < best_dist:
                best_dist = dist
                best_gt = gt_idx

        # Effective TP threshold for the nearest GT box (adaptive when a
        # per-GT table was supplied; the ignore test below always uses the
        # fixed dist_th).
        if best_gt is not None and per_gt_dist_ths is not None:
            eff_dist_th = per_gt_dist_ths.get(
                (pred_box.sample_token, best_gt), dist_th
            )
        else:
            eff_dist_th = dist_th

        if best_dist < eff_dist_th:
            # Outcome 1: true positive.
            taken.add((pred_box.sample_token, best_gt))
            tp.append(1)
            fp.append(0)
            conf.append(pred_box.detection_score)

            gt_match = gt_boxes[pred_box.sample_token][best_gt]
            period = np.pi if class_name == 'barrier' else 2 * np.pi
            match_data['trans_err'].append(center_distance(gt_match, pred_box))
            match_data['vel_err'].append(velocity_l2(gt_match, pred_box))
            match_data['scale_err'].append(1 - scale_iou(gt_match, pred_box))
            match_data['orient_err'].append(
                yaw_diff(gt_match, pred_box, period=period)
            )
            match_data['attr_err'].append(1 - attr_acc(gt_match, pred_box))
            match_data['conf'].append(pred_box.detection_score)
        else:
            # Outcome 2: neutralised by an ignore box (fixed dist_th).
            ignores = ignore_by_token.get(pred_box.sample_token, [])
            if any(dist_fcn(ign, pred_box) < dist_th for ign in ignores):
                continue

            # Outcome 3: false positive.
            tp.append(0)
            fp.append(1)
            conf.append(pred_box.detection_score)

    if len(match_data['trans_err']) == 0:
        return DetectionMetricData.no_predictions()

    # Precision / recall over the confidence-ordered predictions.
    tp = np.cumsum(tp).astype(float)
    fp = np.cumsum(fp).astype(float)
    conf = np.array(conf)

    prec = tp / (fp + tp)
    rec = tp / float(npos)

    # Resample onto the devkit's fixed recall grid.
    rec_interp = np.linspace(0, 1, DetectionMetricData.nelem)
    prec = np.interp(rec_interp, rec, prec, right=0)
    conf = np.interp(rec_interp, rec, conf, right=0)
    rec = rec_interp

    # Interpolate the cumulative-mean TP error curves onto the same grid.
    for key in match_data:
        if key == 'conf':
            continue
        tmp = cummean(np.array(match_data[key]))
        match_data[key] = np.interp(
            conf[::-1], match_data['conf'][::-1], tmp[::-1]
        )[::-1]

    return DetectionMetricData(
        recall=rec, precision=prec, confidence=conf,
        trans_err=match_data['trans_err'], vel_err=match_data['vel_err'],
        scale_err=match_data['scale_err'], orient_err=match_data['orient_err'],
        attr_err=match_data['attr_err'],
    )
# ---------------------------------------------------------------------------
# Base evaluator
# ---------------------------------------------------------------------------

class _IgnoreAwareNuScenesEval(NuScenesEval):
    """Base class for ignore-aware detection evaluators.

    Subclasses populate ``self.gt_boxes`` (scored GT) and
    ``self._ignore_boxes`` (GT that only neutralises unmatched
    predictions), and may supply per-GT adaptive thresholds by overriding
    ``_per_gt_dist_ths``. The previously duplicated ``evaluate`` override
    in the occluded subclass has been folded into this single
    implementation via that hook.
    """

    def _filter_occluded_boxes(self, gt_boxes: EvalBoxes) -> EvalBoxes:
        """Keep in-range boxes with zero lidar points (occluded)."""
        filtered = EvalBoxes()
        for sample_token in gt_boxes.sample_tokens:
            boxes = [
                box for box in gt_boxes[sample_token]
                if box.num_pts == 0
                and box.detection_name in self.cfg.class_range
                and box.ego_dist < self.cfg.class_range[box.detection_name]
            ]
            filtered.add_boxes(sample_token, boxes)
        return filtered

    def _filter_visible_boxes(self, gt_boxes: EvalBoxes) -> EvalBoxes:
        """Keep in-range boxes with at least one lidar point (visible)."""
        filtered = EvalBoxes()
        for sample_token in gt_boxes.sample_tokens:
            boxes = [
                box for box in gt_boxes[sample_token]
                if box.num_pts > 0
                and box.detection_name in self.cfg.class_range
                and box.ego_dist < self.cfg.class_range[box.detection_name]
            ]
            filtered.add_boxes(sample_token, boxes)
        return filtered

    def _per_gt_dist_ths(self, dist_th):
        """Optional per-GT adaptive thresholds for a base ``dist_th``.

        Returns a dict mapping (sample_token, gt_idx) -> threshold, or
        ``None`` to use the fixed ``dist_th`` everywhere (the default).
        """
        return None

    def evaluate(self):
        """Accumulate metric data and compute AP / TP metrics."""
        import time
        from nuscenes.eval.detection.algo import calc_ap, calc_tp
        from nuscenes.eval.detection.data_classes import (
            DetectionMetrics, DetectionMetricDataList,
        )
        from nuscenes.eval.detection.constants import TP_METRICS

        start_time = time.time()
        if self.verbose:
            print('Accumulating metric data...')

        metric_data_list = DetectionMetricDataList()
        for class_name in self.cfg.class_names:
            for dist_th in self.cfg.dist_ths:
                md = accumulate_with_ignore(
                    self.gt_boxes,
                    self.pred_boxes,
                    self._ignore_boxes,
                    class_name,
                    self.cfg.dist_fcn_callable,
                    dist_th,
                    per_gt_dist_ths=self._per_gt_dist_ths(dist_th),
                )
                metric_data_list.set(class_name, dist_th, md)

        if self.verbose:
            print('Calculating metrics...')

        metrics = DetectionMetrics(self.cfg)
        for class_name in self.cfg.class_names:
            for dist_th in self.cfg.dist_ths:
                metric_data = metric_data_list[(class_name, dist_th)]
                ap = calc_ap(
                    metric_data, self.cfg.min_recall, self.cfg.min_precision
                )
                metrics.add_label_ap(class_name, dist_th, ap)

            for metric_name in TP_METRICS:
                metric_data = metric_data_list[(class_name, self.cfg.dist_th_tp)]
                # Some TP metrics are undefined for (near-)static classes.
                if class_name == 'traffic_cone' and metric_name in (
                    'attr_err', 'vel_err', 'orient_err'
                ):
                    tp = np.nan
                elif class_name == 'barrier' and metric_name in (
                    'attr_err', 'vel_err'
                ):
                    tp = np.nan
                else:
                    tp = calc_tp(metric_data, self.cfg.min_recall, metric_name)
                metrics.add_label_tp(class_name, metric_name, tp)

        metrics.add_runtime(time.time() - start_time)
        return metrics, metric_data_list


# ---------------------------------------------------------------------------
# Concrete evaluators
# ---------------------------------------------------------------------------

class OccludedDetectionEval(_IgnoreAwareNuScenesEval):
    """NuScenes detection evaluator restricted to occluded objects.

    Uses an adaptive matching threshold d = dist_th + alpha*t + beta*v*t +
    gamma*a*t^2 where t is occlusion duration, v is speed and a is
    acceleration at the last visible annotation. Coefficients are
    class-specific (Vehicle / Cyclist / Pedestrian via UniTraj mapping).
    Static classes (barrier, traffic_cone) retain the fixed dist_th.
    """

    def __init__(self, nusc, config, result_path, eval_set, output_dir, verbose):
        super().__init__(nusc, config, result_path, eval_set, output_dir, verbose)

        all_gt = load_gt(nusc, eval_set, DetectionBox, verbose=verbose)
        all_gt = add_center_dist(nusc, all_gt)

        # Score against occluded GT; visible GT only neutralises FPs.
        self.gt_boxes = self._filter_occluded_boxes(all_gt)
        self.sample_tokens = self.gt_boxes.sample_tokens
        self._ignore_boxes = self._filter_visible_boxes(all_gt)

        occ_total = sum(len(self.gt_boxes[t]) for t in self.gt_boxes.sample_tokens)
        class_counts = Counter(
            box.detection_name
            for t in self.gt_boxes.sample_tokens
            for box in self.gt_boxes[t]
        )
        vis_total = sum(len(self._ignore_boxes[t]) for t in self._ignore_boxes.sample_tokens)
        print(f'[Occluded Det] GT occluded boxes: {occ_total} | {dict(class_counts)}')
        print(f'[Occluded Det] Ignore (visible) boxes: {vis_total}')

        # Pre-compute adaptive thresholds for each (base_dist_th).
        self._adaptive_dist_ths = self._build_adaptive_dist_ths(nusc)

    def _per_gt_dist_ths(self, dist_th):
        """Adaptive per-GT thresholds (see ``_build_adaptive_dist_ths``)."""
        return self._adaptive_dist_ths.get(dist_th)

    def _build_adaptive_dist_ths(
        self, nusc
    ) -> Dict[float, Dict[Tuple[str, int], float]]:
        """Build per-GT-box adaptive thresholds for every base dist_th.

        Returns
        -------
        dict mapping base_dist_th -> {(sample_token, gt_idx): adaptive_dist_th}
        """
        # Build sample_token -> {rounded_translation: ann_token} for fast
        # lookup; EvalBoxes do not carry their annotation tokens directly.
        sample_ann_map: Dict[str, Dict[tuple, str]] = {}
        for sample_token in self.gt_boxes.sample_tokens:
            sample = nusc.get('sample', sample_token)
            pos_to_tok = {}
            for ann_token in sample['anns']:
                ann = nusc.get('sample_annotation', ann_token)
                key = (round(ann['translation'][0], 2),
                       round(ann['translation'][1], 2))
                pos_to_tok[key] = ann_token
            sample_ann_map[sample_token] = pos_to_tok

        # Compute (t, v, a) and cache the adaptive threshold per box per
        # base dist_th.
        result: Dict[float, Dict[Tuple[str, int], float]] = {
            d: {} for d in self.cfg.dist_ths
        }

        for sample_token in self.gt_boxes.sample_tokens:
            pos_to_tok = sample_ann_map[sample_token]
            for gt_idx, box in enumerate(self.gt_boxes[sample_token]):
                key_pos = (round(box.translation[0], 2),
                           round(box.translation[1], 2))
                ann_token = pos_to_tok.get(key_pos)
                if ann_token is None:
                    # Fallback: keep base dist_th (absent entry -> default).
                    continue

                t, v, a = _occ_metadata(nusc, ann_token)
                for base_dist_th in self.cfg.dist_ths:
                    adaptive = _adaptive_dist_th(base_dist_th, box.detection_name,
                                                 t, v, a)
                    result[base_dist_th][(sample_token, gt_idx)] = adaptive

        return result
'traffic_cone' and metric_name in ('attr_err', 'vel_err', 'orient_err'): + tp = np.nan + elif class_name == 'barrier' and metric_name in ('attr_err', 'vel_err'): + tp = np.nan + else: + tp = calc_tp(metric_data, self.cfg.min_recall, metric_name) + metrics.add_label_tp(class_name, metric_name, tp) + + metrics.add_runtime(time.time() - start_time) + return metrics, metric_data_list + + +class VisibleDetectionEval(_IgnoreAwareNuScenesEval): + """NuScenes detection evaluator on visible objects (num_lidar_pts >= 1).""" + + def __init__(self, nusc, config, result_path, eval_set, output_dir, verbose): + super().__init__(nusc, config, result_path, eval_set, output_dir, verbose) + + all_gt = load_gt(nusc, eval_set, DetectionBox, verbose=verbose) + all_gt = add_center_dist(nusc, all_gt) + self._ignore_boxes = self._filter_occluded_boxes(all_gt) + + vis_total = sum(len(self.gt_boxes[t]) for t in self.gt_boxes.sample_tokens) + class_counts = Counter( + box.detection_name + for t in self.gt_boxes.sample_tokens + for box in self.gt_boxes[t] + ) + occ_total = sum(len(self._ignore_boxes[t]) for t in self._ignore_boxes.sample_tokens) + print(f'[Visible Det] GT visible boxes: {vis_total} | {dict(class_counts)}') + print(f'[Visible Det] Ignore (occluded) boxes: {occ_total}') + + +class AllDetectionEval(NuScenesEval): + """NuScenes detection evaluator on all objects (visible + occluded).""" + + def __init__(self, nusc, config, result_path, eval_set, output_dir, verbose): + super().__init__(nusc, config, result_path, eval_set, output_dir, verbose) + + self.gt_boxes = load_gt(nusc, eval_set, DetectionBox, verbose=verbose) + self.gt_boxes = add_center_dist(nusc, self.gt_boxes) + self.gt_boxes = self._filter_all_gt(self.gt_boxes) + self.sample_tokens = self.gt_boxes.sample_tokens + + def _filter_all_gt(self, gt_boxes: EvalBoxes) -> EvalBoxes: + filtered = EvalBoxes() + for sample_token in gt_boxes.sample_tokens: + boxes = [ + box for box in gt_boxes[sample_token] + if box.detection_name 
in self.cfg.class_range + and box.ego_dist < self.cfg.class_range[box.detection_name] + ] + filtered.add_boxes(sample_token, boxes) + total = sum(len(filtered[t]) for t in filtered.sample_tokens) + class_counts = Counter( + box.detection_name + for t in filtered.sample_tokens + for box in filtered[t] + ) + print(f'[All Det] GT all boxes: {total} | {dict(class_counts)}') + return filtered diff --git a/projects/mmdet3d_plugin/datasets/evaluation/map/AP.py b/projects/mmdet3d_plugin/datasets/evaluation/map/AP.py new file mode 100644 index 0000000..0be3480 --- /dev/null +++ b/projects/mmdet3d_plugin/datasets/evaluation/map/AP.py @@ -0,0 +1,136 @@ +import numpy as np +from .distance import chamfer_distance, frechet_distance, chamfer_distance_batch +from typing import List, Tuple, Union +from numpy.typing import NDArray + +def average_precision(recalls, precisions, mode='area'): + """Calculate average precision. + + Args: + recalls (ndarray): shape (num_dets, ) + precisions (ndarray): shape (num_dets, ) + mode (str): 'area' or '11points', 'area' means calculating the area + under precision-recall curve, '11points' means calculating + the average precision of recalls at [0, 0.1, ..., 1] + + Returns: + float: calculated average precision + """ + + recalls = recalls[np.newaxis, :] + precisions = precisions[np.newaxis, :] + + assert recalls.shape == precisions.shape and recalls.ndim == 2 + num_scales = recalls.shape[0] + ap = 0. 
+
+    if mode == 'area':
+        zeros = np.zeros((num_scales, 1), dtype=recalls.dtype)
+        ones = np.ones((num_scales, 1), dtype=recalls.dtype)
+        mrec = np.hstack((zeros, recalls, ones))
+        mpre = np.hstack((zeros, precisions, zeros))
+        for i in range(mpre.shape[1] - 1, 0, -1):
+            mpre[:, i - 1] = np.maximum(mpre[:, i - 1], mpre[:, i])
+
+        ind = np.where(mrec[0, 1:] != mrec[0, :-1])[0]
+        ap = np.sum(
+            (mrec[0, ind + 1] - mrec[0, ind]) * mpre[0, ind + 1])
+
+    elif mode == '11points':
+        for thr in np.arange(0, 1 + 1e-3, 0.1):
+            precs = precisions[0, recalls[0, :] >= thr]
+            prec = precs.max() if precs.size > 0 else 0
+            ap += prec
+        ap /= 11
+    else:
+        raise ValueError(
+            'Unrecognized mode, only "area" and "11points" are supported')
+
+    return ap
+
+def instance_match(pred_lines: NDArray,
+                   scores: NDArray,
+                   gt_lines: NDArray,
+                   thresholds: Union[Tuple, List],
+                   metric: str='chamfer') -> List:
+    """Compute whether detected lines are true positive or false positive.
+
+    Args:
+        pred_lines (array): Detected lines of a sample, of shape (M, INTERP_NUM, 2 or 3).
+        scores (array): Confidence score of each line, of shape (M, ).
+        gt_lines (array): GT lines of a sample, of shape (N, INTERP_NUM, 2 or 3).
+        thresholds (list of tuple): List of thresholds.
+        metric (str): Distance function for lines matching. Default: 'chamfer'.
+
+    Returns:
+        list_of_tp_fp (list): tp-fp matching result at all thresholds
+    """
+
+    if metric == 'chamfer':
+        distance_fn = chamfer_distance
+
+    elif metric == 'frechet':
+        distance_fn = frechet_distance
+
+    else:
+        raise ValueError(f'unknown distance function {metric}')
+
+    num_preds = pred_lines.shape[0]
+    num_gts = gt_lines.shape[0]
+
+    # tp and fp
+    tp_fp_list = []
+    tp = np.zeros((num_preds), dtype=np.float32)
+    fp = np.zeros((num_preds), dtype=np.float32)
+
+    # if there is no gt lines in this sample, then all pred lines are false positives
+    if num_gts == 0:
+        fp[...]
= 1 + for thr in thresholds: + tp_fp_list.append((tp.copy(), fp.copy())) + return tp_fp_list + + if num_preds == 0: + for thr in thresholds: + tp_fp_list.append((tp.copy(), fp.copy())) + return tp_fp_list + + assert pred_lines.shape[1] == gt_lines.shape[1], \ + "sample points num should be the same" + + # distance matrix: M x N + matrix = np.zeros((num_preds, num_gts)) + + # for i in range(num_preds): + # for j in range(num_gts): + # matrix[i, j] = distance_fn(pred_lines[i], gt_lines[j]) + + matrix = chamfer_distance_batch(pred_lines, gt_lines) + # for each det, the min distance with all gts + matrix_min = matrix.min(axis=1) + + # for each det, which gt is the closest to it + matrix_argmin = matrix.argmin(axis=1) + # sort all dets in descending order by scores + sort_inds = np.argsort(-scores) + + # match under different thresholds + for thr in thresholds: + tp = np.zeros((num_preds), dtype=np.float32) + fp = np.zeros((num_preds), dtype=np.float32) + + gt_covered = np.zeros(num_gts, dtype=bool) + for i in sort_inds: + if matrix_min[i] <= thr: + matched_gt = matrix_argmin[i] + if not gt_covered[matched_gt]: + gt_covered[matched_gt] = True + tp[i] = 1 + else: + fp[i] = 1 + else: + fp[i] = 1 + + tp_fp_list.append((tp, fp)) + + return tp_fp_list \ No newline at end of file diff --git a/projects/mmdet3d_plugin/datasets/evaluation/map/distance.py b/projects/mmdet3d_plugin/datasets/evaluation/map/distance.py new file mode 100644 index 0000000..0152755 --- /dev/null +++ b/projects/mmdet3d_plugin/datasets/evaluation/map/distance.py @@ -0,0 +1,67 @@ +from scipy.spatial import distance +from numpy.typing import NDArray +import torch + +def chamfer_distance(line1: NDArray, line2: NDArray) -> float: + ''' Calculate chamfer distance between two lines. Make sure the + lines are interpolated. 
+ + Args: + line1 (array): coordinates of line1 + line2 (array): coordinates of line2 + + Returns: + distance (float): chamfer distance + ''' + + dist_matrix = distance.cdist(line1, line2, 'euclidean') + dist12 = dist_matrix.min(-1).sum() / len(line1) + dist21 = dist_matrix.min(-2).sum() / len(line2) + + return (dist12 + dist21) / 2 + +def frechet_distance(line1: NDArray, line2: NDArray) -> float: + ''' Calculate frechet distance between two lines. Make sure the + lines are interpolated. + + Args: + line1 (array): coordinates of line1 + line2 (array): coordinates of line2 + + Returns: + distance (float): frechet distance + ''' + + raise NotImplementedError + +def chamfer_distance_batch(pred_lines, gt_lines): + ''' Calculate chamfer distance between two group of lines. Make sure the + lines are interpolated. + + Args: + pred_lines (array or tensor): shape (m, num_pts, 2 or 3) + gt_lines (array or tensor): shape (n, num_pts, 2 or 3) + + Returns: + distance (array): chamfer distance + ''' + _, num_pts, coord_dims = pred_lines.shape + + if not isinstance(pred_lines, torch.Tensor): + pred_lines = torch.tensor(pred_lines) + if not isinstance(gt_lines, torch.Tensor): + gt_lines = torch.tensor(gt_lines) + dist_mat = torch.cdist(pred_lines.view(-1, coord_dims), + gt_lines.view(-1, coord_dims), p=2) + # (num_query*num_points, num_gt*num_points) + dist_mat = torch.stack(torch.split(dist_mat, num_pts)) + # (num_query, num_points, num_gt*num_points) + dist_mat = torch.stack(torch.split(dist_mat, num_pts, dim=-1)) + # (num_gt, num_q, num_pts, num_pts) + + dist1 = dist_mat.min(-1)[0].sum(-1) + dist2 = dist_mat.min(-2)[0].sum(-1) + + dist_matrix = (dist1 + dist2).transpose(0, 1) / (2 * num_pts) + + return dist_matrix.numpy() \ No newline at end of file diff --git a/projects/mmdet3d_plugin/datasets/evaluation/map/vector_eval.py b/projects/mmdet3d_plugin/datasets/evaluation/map/vector_eval.py new file mode 100644 index 0000000..de7e83b --- /dev/null +++ 
b/projects/mmdet3d_plugin/datasets/evaluation/map/vector_eval.py @@ -0,0 +1,279 @@ +import prettytable +from typing import Dict, List, Optional +from time import time +from copy import deepcopy +from multiprocessing import Pool +from logging import Logger +from functools import partial, cached_property + +import numpy as np +from numpy.typing import NDArray +from shapely.geometry import LineString + +import mmcv +from mmcv import Config +from mmdet.datasets import build_dataset, build_dataloader + +from .AP import instance_match, average_precision + +INTERP_NUM = 200 # number of points to interpolate during evaluation +THRESHOLDS = [0.5, 1.0, 1.5] # AP thresholds +N_WORKERS = 8 # num workers to parallel + +class VectorEvaluate(object): + """Evaluator for vectorized map. + + Args: + dataset_cfg (Config): dataset cfg for gt + n_workers (int): num workers to parallel + """ + + def __init__(self, dataset_cfg: Config, n_workers: int=N_WORKERS) -> None: + self.dataset = build_dataset(dataset_cfg) + self.dataloader = build_dataloader( + self.dataset, samples_per_gpu=1, workers_per_gpu=n_workers, shuffle=False, dist=False) + classes = self.dataset.MAP_CLASSES + self.cat2id = {cls: i for i, cls in enumerate(classes)} + self.id2cat = {v: k for k, v in self.cat2id.items()} + self.n_workers = n_workers + self.thresholds = [0.5, 1.0, 1.5] + + @cached_property + def gts(self) -> Dict[str, Dict[int, List[NDArray]]]: + print('collecting gts...') + gts = {} + pbar = mmcv.ProgressBar(len(self.dataloader)) + for data in self.dataloader: + token = deepcopy(data['img_metas'].data[0][0]['token']) + gt = deepcopy(data['vectors'].data[0][0]) + gts[token] = gt + pbar.update() + del data # avoid dataloader memory crash + + return gts + + def interp_fixed_num(self, + vector: NDArray, + num_pts: int) -> NDArray: + ''' Interpolate a polyline. 
+
+        Args:
+            vector (array): line coordinates, shape (M, 2)
+            num_pts (int):
+
+        Returns:
+            sampled_points (array): interpolated coordinates
+        '''
+        line = LineString(vector)
+        distances = np.linspace(0, line.length, num_pts)
+        sampled_points = np.array([list(line.interpolate(distance).coords)
+                                   for distance in distances]).squeeze()
+
+        return sampled_points
+
+    def interp_fixed_dist(self,
+                          vector: NDArray,
+                          sample_dist: float) -> NDArray:
+        ''' Interpolate a line at fixed interval.
+
+        Args:
+            vector (LineString): vector
+            sample_dist (float): sample interval
+
+        Returns:
+            points (array): interpolated points, shape (N, 2)
+        '''
+        line = LineString(vector)
+        distances = list(np.arange(sample_dist, line.length, sample_dist))
+        # make sure to sample at least two points when sample_dist > line.length
+        distances = [0,] + distances + [line.length,]
+
+        sampled_points = np.array([list(line.interpolate(distance).coords)
+                                   for distance in distances]).squeeze()
+
+        return sampled_points
+
+    def _evaluate_single(self,
+                         pred_vectors: List,
+                         scores: List,
+                         groundtruth: List,
+                         thresholds: List,
+                         metric: str='chamfer') -> Dict[int, NDArray]:
+        ''' Do single-frame matching for one class.
+
+        Args:
+            pred_vectors (List): List[vector(ndarray) (different length)],
+            scores (List): List[score(float)]
+            groundtruth (List): List of vectors
+            thresholds (List): List of thresholds
+
+        Returns:
+            tp_fp_score_by_thr (Dict): matching results at different thresholds
+                e.g.
{0.5: (M, 2), 1.0: (M, 2), 1.5: (M, 2)} + ''' + pred_lines = [] + + # interpolate predictions + for vector in pred_vectors: + vector = np.array(vector) + vector_interp = self.interp_fixed_num(vector, INTERP_NUM) + pred_lines.append(vector_interp) + if pred_lines: + pred_lines = np.stack(pred_lines) + else: + pred_lines = np.zeros((0, INTERP_NUM, 2)) + + # interpolate groundtruth + gt_lines = [] + for vector in groundtruth: + vector_interp = self.interp_fixed_num(vector, INTERP_NUM) + gt_lines.append(vector_interp) + if gt_lines: + gt_lines = np.stack(gt_lines) + else: + gt_lines = np.zeros((0, INTERP_NUM, 2)) + + scores = np.array(scores) + tp_fp_list = instance_match(pred_lines, scores, gt_lines, thresholds, metric) # (M, 2) + tp_fp_score_by_thr = {} + for i, thr in enumerate(thresholds): + tp, fp = tp_fp_list[i] + tp_fp_score = np.hstack([tp[:, None], fp[:, None], scores[:, None]]) + tp_fp_score_by_thr[thr] = tp_fp_score + + return tp_fp_score_by_thr # {0.5: (M, 2), 1.0: (M, 2), 1.5: (M, 2)} + + def evaluate(self, + result_path: str, + metric: str='chamfer', + logger: Optional[Logger]=None) -> Dict[str, float]: + ''' Do evaluation for a submission file and print evalution results to `logger` if specified. + The submission will be aligned by tokens before evaluation. We use multi-worker to speed up. + + Args: + result_path (str): path to submission file + metric (str): distance metric. Default: 'chamfer' + logger (Logger): logger to print evaluation result, Default: None + + Returns: + new_result_dict (Dict): evaluation results. AP by categories. 
+ ''' + results = mmcv.load(result_path) + results = results['results'] + + # re-group samples and gt by label + samples_by_cls = {label: [] for label in self.id2cat.keys()} + num_gts = {label: 0 for label in self.id2cat.keys()} + num_preds = {label: 0 for label in self.id2cat.keys()} + + # align by token + for token, gt in self.gts.items(): + if token in results.keys(): + pred = results[token] + else: + pred = {'vectors': [], 'scores': [], 'labels': []} + + # for every sample + vectors_by_cls = {label: [] for label in self.id2cat.keys()} + scores_by_cls = {label: [] for label in self.id2cat.keys()} + + for i in range(len(pred['labels'])): + # i-th pred line in sample + label = pred['labels'][i] + vector = pred['vectors'][i] + score = pred['scores'][i] + + vectors_by_cls[label].append(vector) + scores_by_cls[label].append(score) + + for label in self.id2cat.keys(): + new_sample = (vectors_by_cls[label], scores_by_cls[label], gt[label]) + num_gts[label] += len(gt[label]) + num_preds[label] += len(scores_by_cls[label]) + samples_by_cls[label].append(new_sample) + + result_dict = {} + + print(f'\nevaluating {len(self.id2cat)} categories...') + start = time() + if self.n_workers > 0: + pool = Pool(self.n_workers) + + sum_mAP = 0 + pbar = mmcv.ProgressBar(len(self.id2cat)) + for label in self.id2cat.keys(): + samples = samples_by_cls[label] # List[(pred_lines, scores, gts)] + result_dict[self.id2cat[label]] = { + 'num_gts': num_gts[label], + 'num_preds': num_preds[label] + } + sum_AP = 0 + + fn = partial(self._evaluate_single, thresholds=self.thresholds, metric=metric) + if self.n_workers > 0 and len(samples) > 81: + tpfp_score_list = pool.starmap(fn, samples) + else: + tpfp_score_list = [] + for sample in samples: + tpfp_score_list.append(fn(*sample)) + + for thr in self.thresholds: + tp_fp_score = [i[thr] for i in tpfp_score_list] + tp_fp_score = np.vstack(tp_fp_score) # (num_dets, 3) + sort_inds = np.argsort(-tp_fp_score[:, -1]) + + tp = tp_fp_score[sort_inds, 0] # 
(num_dets,) + fp = tp_fp_score[sort_inds, 1] # (num_dets,) + tp = np.cumsum(tp, axis=0) + fp = np.cumsum(fp, axis=0) + eps = np.finfo(np.float32).eps + recalls = tp / np.maximum(num_gts[label], eps) + precisions = tp / np.maximum((tp + fp), eps) + + AP = average_precision(recalls, precisions, 'area') + sum_AP += AP + result_dict[self.id2cat[label]].update({f'AP@{thr}': AP}) + + pbar.update() + + AP = sum_AP / len(self.thresholds) + sum_mAP += AP + + result_dict[self.id2cat[label]].update({f'AP': AP}) + + if self.n_workers > 0: + pool.close() + + mAP = sum_mAP / len(self.id2cat.keys()) + result_dict.update({'mAP': mAP}) + + print(f"finished in {time() - start:.2f}s") + + # print results + table = prettytable.PrettyTable(['category', 'num_preds', 'num_gts'] + + [f'AP@{thr}' for thr in self.thresholds] + ['AP']) + for label in self.id2cat.keys(): + table.add_row([ + self.id2cat[label], + result_dict[self.id2cat[label]]['num_preds'], + result_dict[self.id2cat[label]]['num_gts'], + *[round(result_dict[self.id2cat[label]][f'AP@{thr}'], 4) for thr in self.thresholds], + round(result_dict[self.id2cat[label]]['AP'], 4), + ]) + + from mmcv.utils import print_log + print_log('\n'+str(table), logger=logger) + mAP_normal = 0 + for label in self.id2cat.keys(): + for thr in self.thresholds: + mAP_normal += result_dict[self.id2cat[label]][f'AP@{thr}'] + mAP_normal = mAP_normal / 9 + + print_log(f'mAP_normal = {mAP_normal:.4f}\n', logger=logger) + # print_log(f'mAP_hard = {mAP_easy:.4f}\n', logger=logger) + + new_result_dict = {} + for name in self.cat2id: + new_result_dict[name] = result_dict[name]['AP'] + new_result_dict['mAP_normal'] = mAP_normal + return new_result_dict \ No newline at end of file diff --git a/projects/mmdet3d_plugin/datasets/evaluation/motion/motion_eval_uniad.py b/projects/mmdet3d_plugin/datasets/evaluation/motion/motion_eval_uniad.py new file mode 100644 index 0000000..c1f27db --- /dev/null +++ 
b/projects/mmdet3d_plugin/datasets/evaluation/motion/motion_eval_uniad.py @@ -0,0 +1,417 @@ +# nuScenes dev-kit. +# Code written by Holger Caesar & Oscar Beijbom, 2018. + +import argparse +import json +import os +import random +import time +import tqdm +from typing import Tuple, Dict, Any + +import numpy as np + +from nuscenes import NuScenes +from nuscenes.eval.common.config import config_factory +from nuscenes.eval.common.data_classes import EvalBoxes +from nuscenes.eval.common.loaders import add_center_dist, filter_eval_boxes +from nuscenes.eval.detection.algo import accumulate, calc_ap, calc_tp +from nuscenes.eval.detection.constants import DETECTION_NAMES, ATTRIBUTE_NAMES, TP_METRICS +from nuscenes.eval.detection.data_classes import DetectionConfig, DetectionMetrics, DetectionBox, \ + DetectionMetricDataList, DetectionMetricData +from nuscenes.eval.detection.render import summary_plot, class_pr_curve, class_tp_curve, dist_pr_curve, visualize_sample +from nuscenes.prediction import PredictHelper, convert_local_coords_to_global +from nuscenes.utils.splits import create_splits_scenes +from nuscenes.eval.detection.utils import category_to_detection_name +from nuscenes.eval.common.utils import quaternion_yaw, Quaternion +from nuscenes.eval.common.utils import center_distance, scale_iou, yaw_diff, velocity_l2, attr_acc, cummean + +from .motion_utils import MotionBox, load_prediction, load_gt, accumulate +MOTION_TP_METRICS = ['min_ade_err', 'min_fde_err', 'miss_rate_err', 'top1_fde_err', 'brier_min_fde_err'] + + +class MotionEval: + """ + This is the official nuScenes detection evaluation code. + Results are written to the provided output_dir. + + nuScenes uses the following detection metrics: + - Mean Average Precision (mAP): Uses center-distance as matching criterion; averaged over distance thresholds. + - True Positive (TP) metrics: Average of translation, velocity, scale, orientation and attribute errors. 
+ - nuScenes Detection Score (NDS): The weighted sum of the above. + + Here is an overview of the functions in this method: + - init: Loads GT annotations and predictions stored in JSON format and filters the boxes. + - run: Performs evaluation and dumps the metric data to disk. + - render: Renders various plots and dumps to disk. + + We assume that: + - Every sample_token is given in the results, although there may be not predictions for that sample. + + Please see https://www.nuscenes.org/object-detection for more details. + """ + def __init__(self, + nusc: NuScenes, + config: DetectionConfig, + result_path: str, + eval_set: str, + output_dir: str = None, + verbose: bool = True, + seconds: int = 12): + """ + Initialize a DetectionEval object. + :param nusc: A NuScenes object. + :param config: A DetectionConfig object. + :param result_path: Path of the nuScenes JSON result file. + :param eval_set: The dataset split to evaluate on, e.g. train, val or test. + :param output_dir: Folder to save plots and results to. + :param verbose: Whether to print to stdout. + """ + self.nusc = nusc + self.result_path = result_path + self.eval_set = eval_set + self.output_dir = output_dir + self.verbose = verbose + self.cfg = config + + # Check result file exists. + # assert os.path.exists(result_path), 'Error: The result file does not exist!' + + # Make dirs. + self.plot_dir = os.path.join(self.output_dir, 'plots') + if not os.path.isdir(self.output_dir): + os.makedirs(self.output_dir) + if not os.path.isdir(self.plot_dir): + os.makedirs(self.plot_dir) + + # Load data. 
+ if verbose: + print('Initializing nuScenes detection evaluation') + self.pred_boxes, self.meta = load_prediction(self.result_path, self.cfg.max_boxes_per_sample, MotionBox, + verbose=verbose) + self.gt_boxes = load_gt(self.nusc, self.eval_set, MotionBox, verbose=verbose, seconds=seconds) + + assert set(self.pred_boxes.sample_tokens) == set(self.gt_boxes.sample_tokens), \ + "Samples in split doesn't match samples in predictions." + + # Add center distances. + self.pred_boxes = add_center_dist(nusc, self.pred_boxes) + self.gt_boxes = add_center_dist(nusc, self.gt_boxes) + + # Filter boxes (distance, points per box, etc.). + if verbose: + print('Filtering predictions') + self.pred_boxes = filter_eval_boxes(nusc, self.pred_boxes, self.cfg.class_range, verbose=verbose) + if verbose: + print('Filtering ground truth annotations') + self.gt_boxes = filter_eval_boxes(nusc, self.gt_boxes, self.cfg.class_range, verbose=verbose) + + self.sample_tokens = self.gt_boxes.sample_tokens + + def evaluate(self) -> Tuple[DetectionMetrics, DetectionMetricDataList]: + """ + Performs the actual evaluation. + :return: A tuple of high-level and the raw metric data. + """ + start_time = time.time() + self.cfg.class_names = ['car', 'pedestrian'] + self.cfg.dist_ths = [2.0] + + # ----------------------------------- + # Step 1: Accumulate metric data for all classes and distance thresholds. + # ----------------------------------- + if self.verbose: + print('Accumulating metric data...') + metric_data_list = DetectionMetricDataList() + metrics = {} + for class_name in self.cfg.class_names: + for dist_th in self.cfg.dist_ths: + md, EPA, EPA_ = accumulate(self.gt_boxes, self.pred_boxes, class_name, self.cfg.dist_fcn_callable, dist_th) + metric_data_list.set(class_name, dist_th, md) + metrics[f'{class_name}_EPA'] = EPA_ + + # ----------------------------------- + # Step 2: Calculate metrics from the data. 
+ # ----------------------------------- + if self.verbose: + print('Calculating metrics...') + for class_name in self.cfg.class_names: + # Compute TP metrics. + for metric_name in MOTION_TP_METRICS: + metric_data = metric_data_list[(class_name, self.cfg.dist_th_tp)] + tp = calc_tp(metric_data, self.cfg.min_recall, metric_name) + metrics[f'{class_name}_{metric_name}'] = tp + + return metrics, metric_data_list + + def render(self, metrics: DetectionMetrics, md_list: DetectionMetricDataList) -> None: + """ + Renders various PR and TP curves. + :param metrics: DetectionMetrics instance. + :param md_list: DetectionMetricDataList instance. + """ + if self.verbose: + print('Rendering PR and TP curves') + + def savepath(name): + return os.path.join(self.plot_dir, name + '.pdf') + + summary_plot(md_list, metrics, min_precision=self.cfg.min_precision, min_recall=self.cfg.min_recall, + dist_th_tp=self.cfg.dist_th_tp, savepath=savepath('summary')) + + for detection_name in self.cfg.class_names: + class_pr_curve(md_list, metrics, detection_name, self.cfg.min_precision, self.cfg.min_recall, + savepath=savepath(detection_name + '_pr')) + + class_tp_curve(md_list, metrics, detection_name, self.cfg.min_recall, self.cfg.dist_th_tp, + savepath=savepath(detection_name + '_tp')) + + for dist_th in self.cfg.dist_ths: + dist_pr_curve(md_list, metrics, dist_th, self.cfg.min_precision, self.cfg.min_recall, + savepath=savepath('dist_pr_' + str(dist_th))) + + def main(self, + plot_examples: int = 0, + render_curves: bool = True) -> Dict[str, Any]: + """ + Main function that loads the evaluation code, visualizes samples, runs the evaluation and renders stat plots. + :param plot_examples: How many example visualizations to write to disk. + :param render_curves: Whether to render PR and TP curves to disk. + :return: A dict that stores the high-level metrics and meta data. + """ + if plot_examples > 0: + # Select a random but fixed subset to plot. 
+ random.seed(42) + sample_tokens = list(self.sample_tokens) + random.shuffle(sample_tokens) + sample_tokens = sample_tokens[:plot_examples] + + # Visualize samples. + example_dir = os.path.join(self.output_dir, 'examples') + if not os.path.isdir(example_dir): + os.mkdir(example_dir) + for sample_token in sample_tokens: + visualize_sample(self.nusc, + sample_token, + self.gt_boxes if self.eval_set != 'test' else EvalBoxes(), + # Don't render test GT. + self.pred_boxes, + eval_range=max(self.cfg.class_range.values()), + savepath=os.path.join(example_dir, '{}.png'.format(sample_token))) + + # Run evaluation. + metrics, metric_data_list = self.evaluate() + + return metrics + +class NuScenesEval(MotionEval): + """ + Dummy class for backward-compatibility. Same as MotionEval. + """ + + +class OccludedMotionEval(MotionEval): + """Evaluates motion prediction specifically for occluded objects (num_lidar_pts == 0). + + Loads GT filtered to only boxes with zero LiDAR points and applies class + + distance filtering without the standard num_pts >= 1 gate. + """ + + def __init__(self, + nusc: NuScenes, + config: DetectionConfig, + result_path: str, + eval_set: str, + output_dir: str = None, + verbose: bool = True, + seconds: int = 12): + self.nusc = nusc + self.result_path = result_path + self.eval_set = eval_set + self.output_dir = output_dir + self.verbose = verbose + self.cfg = config + + # Make dirs. + self.plot_dir = os.path.join(self.output_dir, 'plots') + if not os.path.isdir(self.output_dir): + os.makedirs(self.output_dir) + if not os.path.isdir(self.plot_dir): + os.makedirs(self.plot_dir) + + # Load data. + if verbose: + print('Initializing occluded motion evaluation') + self.pred_boxes, self.meta = load_prediction( + self.result_path, self.cfg.max_boxes_per_sample, MotionBox, verbose=verbose + ) + # Load GT restricted to boxes with zero lidar points. 
+ self.gt_boxes = load_gt( + self.nusc, self.eval_set, MotionBox, verbose=verbose, + seconds=seconds, occluded_only=True + ) + + assert set(self.pred_boxes.sample_tokens) == set(self.gt_boxes.sample_tokens), \ + "Samples in split doesn't match samples in predictions." + + # Add center distances. + self.pred_boxes = add_center_dist(nusc, self.pred_boxes) + self.gt_boxes = add_center_dist(nusc, self.gt_boxes) + + # Filter predictions normally (class + distance + score). + if verbose: + print('Filtering predictions') + self.pred_boxes = filter_eval_boxes(nusc, self.pred_boxes, self.cfg.class_range, verbose=verbose) + + # For GT: apply only class + distance filter; skip num_pts check because + # every occluded box has num_pts == 0 and would be removed by the standard filter. + if verbose: + print('Filtering occluded ground truth (class + distance only)') + self.gt_boxes = self._filter_occluded_gt(self.gt_boxes) + + self.sample_tokens = self.gt_boxes.sample_tokens + + def _filter_occluded_gt(self, gt_boxes): + """Return GT boxes restricted to known classes within their distance range.""" + from collections import Counter + from nuscenes.eval.common.data_classes import EvalBoxes as _EvalBoxes + filtered = _EvalBoxes() + for sample_token in gt_boxes.sample_tokens: + boxes = [ + box for box in gt_boxes[sample_token] + if box.detection_name in self.cfg.class_range + and box.ego_dist < self.cfg.class_range[box.detection_name] + ] + filtered.add_boxes(sample_token, boxes) + total = sum(len(filtered[t]) for t in filtered.sample_tokens) + class_counts = Counter( + box.detection_name + for t in filtered.sample_tokens + for box in filtered[t] + ) + print(f'[Occluded Motion] GT occluded boxes: {total} | {dict(class_counts)}') + return filtered + + +class AllMotionEval(OccludedMotionEval): + """Evaluates motion prediction on all objects (visible + occluded, num_pts >= 0). 
    Identical to OccludedMotionEval but loads GT without the occluded_only
    restriction, so predictions are scored against every annotated object.
    """

    def __init__(self,
                 nusc: NuScenes,
                 config: DetectionConfig,
                 result_path: str,
                 eval_set: str,
                 output_dir: str = None,
                 verbose: bool = True,
                 seconds: int = 12):
        # Mirror the parent's attribute setup; note result_path here is the already-loaded
        # result dict (see load_prediction), not a filesystem path.
        self.nusc = nusc
        self.result_path = result_path
        self.eval_set = eval_set
        self.output_dir = output_dir
        self.verbose = verbose
        self.cfg = config

        self.plot_dir = os.path.join(self.output_dir, 'plots')
        if not os.path.isdir(self.output_dir):
            os.makedirs(self.output_dir)
        if not os.path.isdir(self.plot_dir):
            os.makedirs(self.plot_dir)

        if verbose:
            print('Initializing all-objects motion evaluation')
        self.pred_boxes, self.meta = load_prediction(
            self.result_path, self.cfg.max_boxes_per_sample, MotionBox, verbose=verbose
        )
        # Load GT for all objects (visible + occluded).
        self.gt_boxes = load_gt(
            self.nusc, self.eval_set, MotionBox, verbose=verbose,
            seconds=seconds
        )

        assert set(self.pred_boxes.sample_tokens) == set(self.gt_boxes.sample_tokens), \
            "Samples in split doesn't match samples in predictions."

        self.pred_boxes = add_center_dist(nusc, self.pred_boxes)
        self.gt_boxes = add_center_dist(nusc, self.gt_boxes)

        if verbose:
            print('Filtering predictions')
        self.pred_boxes = filter_eval_boxes(nusc, self.pred_boxes, self.cfg.class_range, verbose=verbose)

        # Skip the num_pts filter so zero-point (occluded) GT boxes are retained too.
        if verbose:
            print('Filtering all ground truth (class + distance only)')
        self.gt_boxes = self._filter_all_gt(self.gt_boxes)

        self.sample_tokens = self.gt_boxes.sample_tokens

    def _filter_all_gt(self, gt_boxes):
        """Return GT boxes for all objects within their distance range (no num_pts filter)."""
        from collections import Counter
        from nuscenes.eval.common.data_classes import EvalBoxes as _EvalBoxes
        filtered = _EvalBoxes()
        for sample_token in gt_boxes.sample_tokens:
            boxes = [
                box for box in gt_boxes[sample_token]
                if box.detection_name in self.cfg.class_range
                and box.ego_dist < self.cfg.class_range[box.detection_name]
            ]
            filtered.add_boxes(sample_token, boxes)
        # Log the surviving box count per class for sanity-checking the filter.
        total = sum(len(filtered[t]) for t in filtered.sample_tokens)
        class_counts = Counter(
            box.detection_name
            for t in filtered.sample_tokens
            for box in filtered[t]
        )
        print(f'[All Motion] GT all boxes: {total} | {dict(class_counts)}')
        return filtered


if __name__ == "__main__":

    # Settings.
    # Command-line interface for standalone detection-style evaluation.
    parser = argparse.ArgumentParser(description='Evaluate nuScenes detection results.',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('result_path', type=str, help='The submission as a JSON file.')
    parser.add_argument('--output_dir', type=str, default='~/nuscenes-metrics',
                        help='Folder to store result metrics, graphs and example visualizations.')
    parser.add_argument('--eval_set', type=str, default='val',
                        help='Which dataset split to evaluate on, train, val or test.')
    parser.add_argument('--dataroot', type=str, default='/data/sets/nuscenes',
                        help='Default nuScenes data directory.')
    parser.add_argument('--version', type=str, default='v1.0-trainval',
                        help='Which version of the nuScenes dataset to evaluate on, e.g. v1.0-trainval.')
    parser.add_argument('--config_path', type=str, default='',
                        help='Path to the configuration file.'
                             'If no path given, the CVPR 2019 configuration will be used.')
    parser.add_argument('--plot_examples', type=int, default=10,
                        help='How many example visualizations to write to disk.')
    parser.add_argument('--render_curves', type=int, default=1,
                        help='Whether to render PR and TP curves to disk.')
    parser.add_argument('--verbose', type=int, default=1,
                        help='Whether to print to stdout.')
    args = parser.parse_args()

    result_path_ = os.path.expanduser(args.result_path)
    output_dir_ = os.path.expanduser(args.output_dir)
    eval_set_ = args.eval_set
    dataroot_ = args.dataroot
    version_ = args.version
    config_path = args.config_path
    plot_examples_ = args.plot_examples
    render_curves_ = bool(args.render_curves)
    verbose_ = bool(args.verbose)

    # Fall back to the standard CVPR 2019 detection config when none is supplied.
    if config_path == '':
        cfg_ = config_factory('detection_cvpr_2019')
    else:
        with open(config_path, 'r') as _f:
            cfg_ = DetectionConfig.deserialize(json.load(_f))

    nusc_ = NuScenes(version=version_, verbose=verbose_, dataroot=dataroot_)
    nusc_eval = DetectionEval(nusc_, config=cfg_, result_path=result_path_, eval_set=eval_set_,
                              output_dir=output_dir_, verbose=verbose_)
    nusc_eval.main(plot_examples=plot_examples_, render_curves=render_curves_)
diff --git a/projects/mmdet3d_plugin/datasets/evaluation/motion/motion_utils.py b/projects/mmdet3d_plugin/datasets/evaluation/motion/motion_utils.py
new file mode 100644
index 0000000..2a522b4
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/evaluation/motion/motion_utils.py
@@ -0,0 +1,680 @@
# nuScenes dev-kit.
# Code written by Holger Caesar & Oscar Beijbom, 2018.

import argparse
import json
import os
import random
import time
import tqdm
from typing import Tuple, Dict, Any, Callable

import numpy as np

from nuscenes import NuScenes
from nuscenes.eval.common.config import config_factory
from nuscenes.eval.common.data_classes import EvalBoxes
from nuscenes.eval.common.loaders import add_center_dist, filter_eval_boxes
from nuscenes.eval.detection.algo import calc_ap, calc_tp
from nuscenes.eval.detection.constants import DETECTION_NAMES, ATTRIBUTE_NAMES, TP_METRICS
from nuscenes.eval.detection.data_classes import DetectionConfig, DetectionMetrics, DetectionBox, \
    DetectionMetricDataList, DetectionMetricData
from nuscenes.eval.detection.render import summary_plot, class_pr_curve, class_tp_curve, dist_pr_curve, visualize_sample
from nuscenes.prediction import PredictHelper, convert_local_coords_to_global
from nuscenes.utils.splits import create_splits_scenes
from nuscenes.eval.detection.utils import category_to_detection_name
from nuscenes.eval.common.utils import quaternion_yaw, Quaternion
from nuscenes.eval.common.utils import center_distance, scale_iou, yaw_diff, velocity_l2, attr_acc, cummean


# Collapses the 10 nuScenes detection classes into 3 coarse motion-evaluation
# classes: all vehicles -> 'car', pedestrians stay, static obstacles -> 'barrier'.
motion_name_mapping = {
    'car': 'car',
    'truck': 'car',
    'construction_vehicle': 'car',
    'bus': 'car',
    'trailer': 'car',
    'motorcycle': 'car',
    'bicycle': 'car',
    'pedestrian': 'pedestrian',
    'traffic_cone': 'barrier',
    'barrier': 'barrier',
}


class MotionBox(DetectionBox):
    """ Data class used during
 detection evaluation. Can be a prediction or ground truth."""

    def __init__(self,
                 sample_token: str = "",
                 translation: Tuple[float, float, float] = (0, 0, 0),
                 size: Tuple[float, float, float] = (0, 0, 0),
                 rotation: Tuple[float, float, float, float] = (0, 0, 0, 0),
                 velocity: Tuple[float, float] = (0, 0),
                 ego_translation: [float, float, float] = (0, 0, 0),  # Translation to ego vehicle in meters.
                 num_pts: int = -1,  # Nbr. LIDAR or RADAR inside the box. Only for gt boxes.
                 detection_name: str = 'car',  # The class name used in the detection challenge.
                 detection_score: float = -1.0,  # GT samples do not have a score.
                 attribute_name: str = '',  # Box attribute. Each box can have at most 1 attribute.
                 traj=None,  # Future trajectory waypoints; GT uses scene-centric coords (see load_gt).
                 traj_score=None):  # Optional per-mode confidence scores for multi-modal predictions.

        # Delegate the geometric fields to DetectionBox; the detection fields are
        # re-validated and assigned below.
        super().__init__(sample_token, translation, size, rotation, velocity, ego_translation, num_pts)

        assert detection_name is not None, 'Error: detection_name cannot be empty!'
        assert detection_name in DETECTION_NAMES, 'Error: Unknown detection_name %s' % detection_name

        assert attribute_name in ATTRIBUTE_NAMES or attribute_name == '', \
            'Error: Unknown attribute_name %s' % attribute_name

        assert type(detection_score) == float, 'Error: detection_score must be a float!'
        assert not np.any(np.isnan(detection_score)), 'Error: detection_score may not be NaN!'

        # Assign.
        self.detection_name = detection_name
        self.detection_score = detection_score
        self.attribute_name = attribute_name
        self.traj = traj
        self.traj_score = traj_score

    def __eq__(self, other):
        # NOTE(review): traj_score is deliberately (?) not part of equality — confirm
        # whether two boxes differing only in mode scores should compare equal.
        return (self.sample_token == other.sample_token and
                self.translation == other.translation and
                self.size == other.size and
                self.rotation == other.rotation and
                self.velocity == other.velocity and
                self.ego_translation == other.ego_translation and
                self.num_pts == other.num_pts and
                self.detection_name == other.detection_name and
                self.detection_score == other.detection_score and
                self.attribute_name == other.attribute_name and
                np.all(self.traj == other.traj))

    def serialize(self) -> dict:
        """ Serialize instance into json-friendly format. """
        return {
            'sample_token': self.sample_token,
            'translation': self.translation,
            'size': self.size,
            'rotation': self.rotation,
            'velocity': self.velocity,
            'ego_translation': self.ego_translation,
            'num_pts': self.num_pts,
            'detection_name': self.detection_name,
            'detection_score': self.detection_score,
            'attribute_name': self.attribute_name,
            'traj': self.traj,
            'traj_score': self.traj_score,
        }

    @classmethod
    def deserialize(cls, content: dict):
        """ Initialize from serialized content.
""" + return cls(sample_token=content['sample_token'], + translation=tuple(content['translation']), + size=tuple(content['size']), + rotation=tuple(content['rotation']), + velocity=tuple(content['velocity']), + ego_translation=(0.0, 0.0, 0.0) if 'ego_translation' not in content + else tuple(content['ego_translation']), + num_pts=-1 if 'num_pts' not in content else int(content['num_pts']), + detection_name=content['detection_name'], + detection_score=-1.0 if 'detection_score' not in content else float(content['detection_score']), + attribute_name=content['attribute_name'], + traj=content['trajs'], + traj_score=content.get('trajs_score', None),) + + +def load_prediction(result_path: str, max_boxes_per_sample: int, box_cls, verbose: bool = False) \ + -> Tuple[EvalBoxes, Dict]: + """ + Loads object predictions from file. + :param result_path: Path to the .json result file provided by the user. + :param max_boxes_per_sample: Maximim number of boxes allowed per sample. + :param box_cls: Type of box to load, e.g. DetectionBox or TrackingBox. + :param verbose: Whether to print messages to stdout. + :return: The deserialized results and meta data. + """ + + # Load from file and check that the format is correct. + # with open(result_path) as f: + # data = json.load(f) + data = result_path + assert 'results' in data, 'Error: No field `results` in result file. Please note that the result format changed.' \ + 'See https://www.nuscenes.org/object-detection for more information.' + + # motion name mapping + for key in data['results'].keys(): + for i in range(len(data['results'][key])): + cls_name = data['results'][key][i]['detection_name'] + if cls_name in motion_name_mapping: + cls_name = motion_name_mapping[cls_name] + data['results'][key][i]['detection_name'] = cls_name + + # Deserialize results and get meta data. + all_results = EvalBoxes.deserialize(data['results'], box_cls) + meta = data['meta'] + if verbose: + print("Loaded results from {}. 
Found detections for {} samples." + .format(result_path, len(all_results.sample_tokens))) + + # Check that each sample has no more than x predicted boxes. + for sample_token in all_results.sample_tokens: + assert len(all_results.boxes[sample_token]) <= max_boxes_per_sample, \ + "Error: Only <= %d boxes per sample allowed!" % max_boxes_per_sample + + return all_results, meta + + +def load_gt(nusc: NuScenes, eval_split: str, box_cls, verbose: bool = False, seconds: int = 12, occluded_only: bool = False) -> EvalBoxes: + """ + Loads ground truth boxes from DB. + :param nusc: A NuScenes instance. + :param eval_split: The evaluation split for which we load GT boxes. + :param box_cls: Type of box to load, e.g. DetectionBox or TrackingBox. + :param verbose: Whether to print messages to stdout. + :return: The GT boxes. + """ + predict_helper = PredictHelper(nusc) + # Init. + if box_cls == MotionBox: + attribute_map = {a['token']: a['name'] for a in nusc.attribute} + + if verbose: + print('Loading annotations for {} split from nuScenes version: {}'.format(eval_split, nusc.version)) + # Read out all sample_tokens in DB. + sample_tokens_all = [s['token'] for s in nusc.sample] + assert len(sample_tokens_all) > 0, "Error: Database has no samples!" + + # Only keep samples from this split. + splits = create_splits_scenes() + + # Check compatibility of split with nusc_version. 
    # Verify that the requested split exists in the loaded dataset version.
    version = nusc.version
    if eval_split in {'train', 'val', 'train_detect', 'train_track'}:
        assert version.endswith('trainval'), \
            'Error: Requested split {} which is not compatible with NuScenes version {}'.format(eval_split, version)
    elif eval_split in {'mini_train', 'mini_val'}:
        assert version.endswith('mini'), \
            'Error: Requested split {} which is not compatible with NuScenes version {}'.format(eval_split, version)
    elif eval_split == 'test':
        assert version.endswith('test'), \
            'Error: Requested split {} which is not compatible with NuScenes version {}'.format(eval_split, version)
    else:
        raise ValueError('Error: Requested split {} which this function cannot map to the correct NuScenes version.'
                         .format(eval_split))

    if eval_split == 'test':
        # Check that you aren't trying to cheat :).
        assert len(nusc.sample_annotation) > 0, \
            'Error: You are trying to evaluate on the test set but you do not have the annotations!'

    # Keep only samples whose scene belongs to the requested split.
    sample_tokens = []
    for sample_token in sample_tokens_all:
        scene_token = nusc.get('sample', sample_token)['scene_token']
        scene_record = nusc.get('scene', scene_token)
        if scene_record['name'] in splits[eval_split]:
            sample_tokens.append(sample_token)

    all_annotations = EvalBoxes()

    # Load annotations and filter predictions and annotations.
    tracking_id_set = set()
    for sample_token in tqdm.tqdm(sample_tokens, leave=verbose):

        sample = nusc.get('sample', sample_token)
        sample_annotation_tokens = sample['anns']

        sample_boxes = []
        for sample_annotation_token in sample_annotation_tokens:

            sample_annotation = nusc.get('sample_annotation', sample_annotation_token)
            if box_cls == MotionBox:
                # Get label name in detection task and filter unused labels.
                detection_name = category_to_detection_name(sample_annotation['category_name'])
                # Collapse onto the coarse motion classes (see motion_name_mapping).
                if detection_name in motion_name_mapping:
                    detection_name = motion_name_mapping[detection_name]

                if detection_name is None:
                    continue

                # Get attribute_name.
                attr_tokens = sample_annotation['attribute_tokens']
                attr_count = len(attr_tokens)
                if attr_count == 0:
                    attribute_name = ''
                elif attr_count == 1:
                    attribute_name = attribute_map[attr_tokens[0]]
                else:
                    raise Exception('Error: GT annotations must not have more than one attribute!')

                # Fetch the agent's future trajectory and convert it from the agent's
                # local frame into scene-centric (global) coordinates.
                instance_token = nusc.get('sample_annotation', sample_annotation['token'])['instance_token']
                fut_traj_local = predict_helper.get_future_for_agent(
                    instance_token,
                    sample_token,
                    seconds=seconds,
                    in_agent_frame=True
                )
                if fut_traj_local.shape[0] > 0:
                    _, boxes, _ = nusc.get_sample_data(sample['data']['LIDAR_TOP'], selected_anntokens=[sample_annotation['token']])
                    box = boxes[0]
                    trans = box.center
                    rot = Quaternion(matrix=box.rotation_matrix)
                    fut_traj_scence_centric = convert_local_coords_to_global(fut_traj_local, trans, rot)
                else:
                    # No future annotations available for this agent: empty trajectory.
                    fut_traj_scence_centric = np.zeros((0,))

                num_pts = sample_annotation['num_lidar_pts'] + sample_annotation['num_radar_pts']
                # In occluded-only mode, keep only boxes no sensor can see (zero points).
                if occluded_only and num_pts > 0:
                    continue
                sample_boxes.append(
                    box_cls(
                        sample_token=sample_token,
                        translation=sample_annotation['translation'],
                        size=sample_annotation['size'],
                        rotation=sample_annotation['rotation'],
                        velocity=nusc.box_velocity(sample_annotation['token'])[:2],
                        num_pts=num_pts,
                        detection_name=detection_name,
                        detection_score=-1.0,  # GT samples do not have a score.
                        attribute_name=attribute_name,
                        traj=fut_traj_scence_centric
                    )
                )
            elif box_cls == TrackingBox:
                # Use nuScenes token as tracking id.
                tracking_id = sample_annotation['instance_token']
                tracking_id_set.add(tracking_id)

                # Get label name in detection task and filter unused labels.
                # Import locally to avoid errors when motmetrics package is not installed.
                from nuscenes.eval.tracking.utils import category_to_tracking_name
                tracking_name = category_to_tracking_name(sample_annotation['category_name'])
                if tracking_name is None:
                    continue

                sample_boxes.append(
                    box_cls(
                        sample_token=sample_token,
                        translation=sample_annotation['translation'],
                        size=sample_annotation['size'],
                        rotation=sample_annotation['rotation'],
                        velocity=nusc.box_velocity(sample_annotation['token'])[:2],
                        num_pts=sample_annotation['num_lidar_pts'] + sample_annotation['num_radar_pts'],
                        tracking_id=tracking_id,
                        tracking_name=tracking_name,
                        tracking_score=-1.0  # GT samples do not have a score.
                    )
                )
            else:
                raise NotImplementedError('Error: Invalid box_cls %s!' % box_cls)

        all_annotations.add_boxes(sample_token, sample_boxes)

    if verbose:
        print("Loaded ground truth annotations for {} samples.".format(len(all_annotations.sample_tokens)))

    return all_annotations


def accumulate(gt_boxes: EvalBoxes,
               pred_boxes: EvalBoxes,
               class_name: str,
               dist_fcn: Callable,
               dist_th: float,
               verbose: bool = False) -> DetectionMetricData:
    """
    Average Precision over predefined different recall thresholds for a single distance threshold.
    The recall/conf thresholds and other raw metrics will be used in secondary metrics.
    :param gt_boxes: Maps every sample_token to a list of its sample_annotations.
    :param pred_boxes: Maps every sample_token to a list of its sample_results.
    :param class_name: Class to compute AP on.
    :param dist_fcn: Distance function used to match detections and ground truths.
    :param dist_th: Distance threshold for a match.
    :param verbose: If true, print debug messages.
    :return: (average_prec, metrics). The average precision value and raw data for a number of metrics.
    """
    # ---------------------------------------------
    # Organize input and initialize accumulators.
+ # --------------------------------------------- + + # Count the positives. + npos = len([1 for gt_box in gt_boxes.all if gt_box.detection_name == class_name]) + if verbose: + print("Found {} GT of class {} out of {} total across {} samples.". + format(npos, class_name, len(gt_boxes.all), len(gt_boxes.sample_tokens))) + + # For missing classes in the GT, return a data structure corresponding to no predictions. + if npos == 0: + return DetectionMetricData.no_predictions(), 0 + + # Organize the predictions in a single list. + pred_boxes_list = [box for box in pred_boxes.all if box.detection_name == class_name] + pred_confs = [box.detection_score for box in pred_boxes_list] + + if verbose: + print("Found {} PRED of class {} out of {} total across {} samples.". + format(len(pred_confs), class_name, len(pred_boxes.all), len(pred_boxes.sample_tokens))) + + # Sort by confidence. + sortind = [i for (v, i) in sorted((v, i) for (i, v) in enumerate(pred_confs))][::-1] + + # Do the actual matching. + tp = [] # Accumulator of true positives. + fp = [] # Accumulator of false positives. + conf = [] # Accumulator of confidences. + hit = 0 # Accumulator of matched and hit + + # match_data holds the extra metrics we calculate for each match. + match_data = {'conf': [], + 'min_ade': [], + 'min_fde': [], + 'miss_rate': [], + 'top1_fde': [], + 'brier_min_fde': []} + + # --------------------------------------------- + # Match and accumulate match data. + # --------------------------------------------- + + taken = set() # Initially no gt bounding box is matched. 
+ for ind in sortind: + pred_box = pred_boxes_list[ind] + min_dist = np.inf + match_gt_idx = None + + for gt_idx, gt_box in enumerate(gt_boxes[pred_box.sample_token]): + + # Find closest match among ground truth boxes + if gt_box.detection_name == class_name and not (pred_box.sample_token, gt_idx) in taken: + this_distance = dist_fcn(gt_box, pred_box) + if this_distance < min_dist: + min_dist = this_distance + match_gt_idx = gt_idx + + # If the closest match is close enough according to threshold we have a match! + is_match = min_dist < dist_th + + if is_match: + taken.add((pred_box.sample_token, match_gt_idx)) + + # Update tp, fp and confs. + tp.append(1) + fp.append(0) + conf.append(pred_box.detection_score) + + # Since it is a match, update match data also. + gt_box_match = gt_boxes[pred_box.sample_token][match_gt_idx] + + match_data['conf'].append(pred_box.detection_score) + + minade, minfde, mr, top1_fde, brier_min_fde = prediction_metrics(gt_box_match, pred_box) + match_data['min_ade'].append(minade) + match_data['min_fde'].append(minfde) + match_data['miss_rate'].append(mr) + match_data['top1_fde'].append(top1_fde) + match_data['brier_min_fde'].append(brier_min_fde) + + if minfde < 2.0: + hit += 1 + + else: + # No match. Mark this as a false positive. + tp.append(0) + fp.append(1) + conf.append(pred_box.detection_score) + + # Check if we have any matches. If not, just return a "no predictions" array. + if len(match_data['min_ade']) == 0: + return MotionMetricData.no_predictions() + + # Accumulate. + N_tp = np.sum(tp) + N_fp = np.sum(fp) + tp = np.cumsum(tp).astype(float) + fp = np.cumsum(fp).astype(float) + conf = np.array(conf) + + # Calculate precision and recall. + prec = tp / (fp + tp) + rec = tp / float(npos) + + rec_interp = np.linspace(0, 1, DetectionMetricData.nelem) # 101 steps, from 0% to 100% recall. 
+ prec = np.interp(rec_interp, rec, prec, right=0) + conf = np.interp(rec_interp, rec, conf, right=0) + rec = rec_interp + + # --------------------------------------------- + # Re-sample the match-data to match, prec, recall and conf. + # --------------------------------------------- + confs = np.array(match_data['conf']) + all_same_conf = confs.max() == confs.min() + + for key in match_data.keys(): + if key == "conf": + continue # Confidence is used as reference to align with fp and tp. So skip in this step. + + arr = np.array(match_data[key]) + if all_same_conf: + # When all predictions have identical confidence (e.g. a GT-oracle model), + # the recall-weighted np.interp degenerates because xp is constant and + # np.interp requires a strictly-increasing sequence. Fall back to the + # plain mean replicated across all recall levels. + match_data[key] = np.full(DetectionMetricData.nelem, arr.mean()) + else: + # For each match_data, we first calculate the accumulated mean. + tmp = cummean(arr) + # Then interpolate based on the confidences. (Note reversing since np.interp needs increasing arrays) + match_data[key] = np.interp(conf[::-1], confs[::-1], tmp[::-1])[::-1] + EPA = (hit - 0.5 * N_fp) / npos + + ## match based on traj + traj_matched = 0 + taken = set() # Initially no gt bounding box is matched. + for ind in sortind: + pred_box = pred_boxes_list[ind] + min_dist = np.inf + match_gt_idx = None + + for gt_idx, gt_box in enumerate(gt_boxes[pred_box.sample_token]): + + # Find closest match among ground truth boxes + if gt_box.detection_name == class_name and not (pred_box.sample_token, gt_idx) in taken: + this_distance = dist_fcn(gt_box, pred_box) + if this_distance < min_dist: + min_dist = this_distance + match_gt_idx = gt_idx + fde_distance = traj_fde(gt_box, pred_box, final_step=12) + + # If the closest match is close enough according to threshold we have a match! 
        # A trajectory match additionally requires the endpoint error to be below 2m.
        # NOTE(review): `fde_distance` is only defined once a candidate GT was found;
        # the `and` short-circuit protects the no-candidate case (min_dist stays inf).
        is_match = min_dist < dist_th and fde_distance < 2.0
        if is_match:
            taken.add((pred_box.sample_token, match_gt_idx))
            traj_matched += 1
    EPA_ = (traj_matched - 0.5 * N_fp) / npos  ## same as UniAD

    # ---------------------------------------------
    # Done. Instantiate MetricData and return
    # ---------------------------------------------
    return MotionMetricData(recall=rec,
                            precision=prec,
                            confidence=conf,
                            min_ade_err=match_data['min_ade'],
                            min_fde_err=match_data['min_fde'],
                            miss_rate_err=match_data['miss_rate'],
                            top1_fde_err=match_data['top1_fde'],
                            brier_min_fde_err=match_data['brier_min_fde']), EPA, EPA_


def prediction_metrics(gt_box_match, pred_box, miss_thresh=2):
    # Computes per-match trajectory errors between a matched GT box and a
    # (possibly multi-modal) predicted box. Returns
    # (minADE, minFDE, miss-rate flag, top-1 FDE, Brier-minFDE).
    gt_traj = np.array(gt_box_match.traj)
    pred_traj = np.array(pred_box.traj)

    # Only the GT-covered horizon is scored; an empty GT trajectory yields zeros.
    valid_step = gt_traj.shape[0]
    if valid_step <= 0:
        return 0, 0, 0, 0, 0

    pred_traj_valid = pred_traj[:, :valid_step, :]
    dist = np.linalg.norm(pred_traj_valid - gt_traj[np.newaxis], axis=2)

    minade = dist.mean(axis=1).min()
    minfde = dist[:, -1].min()
    mr = dist.max(axis=1).min() > miss_thresh

    # Top-1 FDE: FDE of the highest-confidence mode.
    # Brier-minFDE: minFDE + (1 - p_best)^2, where p_best is the normalized
    # probability assigned to the mode closest to GT (nuScenes leaderboard metric).
    traj_score = getattr(pred_box, 'traj_score', None)
    if traj_score is not None and len(traj_score) == pred_traj.shape[0]:
        scores = np.array(traj_score, dtype=np.float64)
        top1_idx = int(np.argmax(scores))
        top1_fde = float(dist[top1_idx, -1])

        scores_sum = scores.sum()
        # Guard against a zero/near-zero score sum: fall back to uniform mode probabilities.
        probs = scores / scores_sum if scores_sum > 1e-6 else np.ones(len(scores)) / len(scores)
        best_mode_idx = int(np.argmin(dist[:, -1]))
        p_best = float(probs[best_mode_idx])
        brier_min_fde = minfde + (1.0 - p_best) ** 2
    else:
        # No per-mode scores available: fall back to min-FDE for top1,
        # and worst-case confidence penalty for Brier-minFDE.
+ top1_fde = minfde + brier_min_fde = minfde + 1.0 + + return minade, minfde, mr, top1_fde, brier_min_fde + +def traj_fde(gt_box, pred_box, final_step): + if gt_box.traj.shape[0] <= 0: + return np.inf + final_step = min(gt_box.traj.shape[0], final_step) + gt_final = gt_box.traj[None, final_step-1] + pred_final = np.array(pred_box.traj)[:,final_step-1,:] + err = gt_final - pred_final + err = np.sqrt(np.sum(np.square(gt_final - pred_final), axis=-1)) + return np.min(err) + + +class MotionMetricDataList(DetectionMetricDataList): + """ This stores a set of MetricData in a dict indexed by (name, match-distance). """ + @classmethod + def deserialize(cls, content: dict): + mdl = cls() + for key, md in content.items(): + name, distance = key.split(':') + mdl.set(name, float(distance), MotionMetricData.deserialize(md)) + return mdl + +class MotionMetricData(DetectionMetricData): + """ This class holds accumulated and interpolated data required to calculate the detection metrics. """ + + nelem = 101 + + def __init__(self, + recall: np.array, + precision: np.array, + confidence: np.array, + min_ade_err: np.array, + min_fde_err: np.array, + miss_rate_err: np.array, + top1_fde_err: np.array, + brier_min_fde_err: np.array): + + # Assert lengths. + assert len(recall) == self.nelem + assert len(precision) == self.nelem + assert len(confidence) == self.nelem + assert len(min_ade_err) == self.nelem + assert len(min_fde_err) == self.nelem + assert len(miss_rate_err) == self.nelem + assert len(top1_fde_err) == self.nelem + assert len(brier_min_fde_err) == self.nelem + + # Assert ordering. + assert all(confidence == sorted(confidence, reverse=True)) # Confidences should be descending. + assert all(recall == sorted(recall)) # Recalls should be ascending. + + # Set attributes explicitly to help IDEs figure out what is going on. 
        self.recall = recall
        self.precision = precision
        self.confidence = confidence
        self.min_ade_err = min_ade_err
        self.min_fde_err = min_fde_err
        self.miss_rate_err = miss_rate_err
        self.top1_fde_err = top1_fde_err
        self.brier_min_fde_err = brier_min_fde_err

    def __eq__(self, other):
        # Two instances are equal iff every serialized field matches elementwise.
        eq = True
        for key in self.serialize().keys():
            eq = eq and np.array_equal(getattr(self, key), getattr(other, key))
        return eq

    @property
    def max_recall_ind(self):
        """ Returns index of max recall achieved. """

        # Last instance of confidence > 0 is index of max achieved recall.
        non_zero = np.nonzero(self.confidence)[0]
        if len(non_zero) == 0:  # If there are no matches, all the confidence values will be zero.
            max_recall_ind = 0
        else:
            max_recall_ind = non_zero[-1]

        return max_recall_ind

    @property
    def max_recall(self):
        """ Returns max recall achieved. """

        return self.recall[self.max_recall_ind]

    def serialize(self):
        """ Serialize instance into json-friendly format. """
        return {
            'recall': self.recall.tolist(),
            'precision': self.precision.tolist(),
            'confidence': self.confidence.tolist(),
            'min_ade_err': self.min_ade_err.tolist(),
            'min_fde_err': self.min_fde_err.tolist(),
            'miss_rate_err': self.miss_rate_err.tolist(),
            'top1_fde_err': self.top1_fde_err.tolist(),
            'brier_min_fde_err': self.brier_min_fde_err.tolist(),
        }

    @classmethod
    def deserialize(cls, content: dict):
        """ Initialize from serialized content. """
        return cls(recall=np.array(content['recall']),
                   precision=np.array(content['precision']),
                   confidence=np.array(content['confidence']),
                   min_ade_err=np.array(content['min_ade_err']),
                   min_fde_err=np.array(content['min_fde_err']),
                   miss_rate_err=np.array(content['miss_rate_err']),
                   top1_fde_err=np.array(content['top1_fde_err']),
                   brier_min_fde_err=np.array(content['brier_min_fde_err']))

    @classmethod
    def no_predictions(cls):
        """ Returns a md instance corresponding to having no predictions.
        """
        return cls(recall=np.linspace(0, 1, cls.nelem),
                   precision=np.zeros(cls.nelem),
                   confidence=np.zeros(cls.nelem),
                   min_ade_err=np.ones(cls.nelem),
                   min_fde_err=np.ones(cls.nelem),
                   miss_rate_err=np.ones(cls.nelem),
                   top1_fde_err=np.ones(cls.nelem),
                   brier_min_fde_err=np.ones(cls.nelem) * 2.0)

    @classmethod
    def random_md(cls):
        """ Returns an md instance corresponding to a random results. """
        return cls(recall=np.linspace(0, 1, cls.nelem),
                   precision=np.random.random(cls.nelem),
                   confidence=np.linspace(0, 1, cls.nelem)[::-1],
                   min_ade_err=np.random.random(cls.nelem),
                   min_fde_err=np.random.random(cls.nelem),
                   miss_rate_err=np.random.random(cls.nelem),
                   top1_fde_err=np.random.random(cls.nelem),
                   brier_min_fde_err=np.random.random(cls.nelem))

diff --git a/projects/mmdet3d_plugin/datasets/evaluation/planning/planning_eval.py b/projects/mmdet3d_plugin/datasets/evaluation/planning/planning_eval.py
new file mode 100644
index 0000000..daeb457
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/evaluation/planning/planning_eval.py
@@ -0,0 +1,209 @@
from tqdm import tqdm
import torch
import torch.nn as nn
import numpy as np
from shapely.geometry import Polygon

from mmcv.utils import print_log
from mmdet.datasets import build_dataset, build_dataloader

from projects.mmdet3d_plugin.datasets.utils import box3d_to_corners


def check_collision(ego_box, boxes):
    '''
    ego_box: tensor with shape [7], [x, y, z, w, l, h, yaw]
    boxes: tensor with shape [N, 7]
    '''
    # No obstacles -> no collision.
    if boxes.shape[0] == 0:
        return False

    # follow uniad, add a 0.5m offset
    # NOTE(review): this mutates ego_box in place; callers currently pass a clone.
    ego_box[0] += 0.5 * torch.cos(ego_box[6])
    ego_box[1] += 0.5 * torch.sin(ego_box[6])
    # Use the 4 bottom corners (BEV footprint) for polygon intersection tests.
    ego_corners_box = box3d_to_corners(ego_box.unsqueeze(0))[0, [0, 3, 7, 4], :2]
    corners_box = box3d_to_corners(boxes)[:, [0, 3, 7, 4], :2]
    ego_poly = Polygon([(point[0], point[1]) for point in ego_corners_box])
    for i in range(len(corners_box)):
        box_poly = Polygon([(point[0], point[1]) for point in
 corners_box[i]])
        collision = ego_poly.intersects(box_poly)
        if collision:
            return True

    return False

def get_yaw(traj):
    # Estimates a per-step heading for a 2D trajectory via central differences.
    # Near-stationary trajectories (< 0.5m total displacement) get a constant
    # pi/2 heading (i.e. facing forward) since direction is ill-defined.
    start = traj[0]
    end = traj[-1]
    dist = torch.linalg.norm(end - start, dim=-1)
    if dist < 0.5:
        return traj.new_ones(traj.shape[0]) * np.pi / 2

    # Prepend the origin so the first step has a predecessor for differencing.
    zeros = traj.new_zeros((1, 2))
    traj_cat = torch.cat([zeros, traj], dim=0)
    yaw = traj.new_zeros(traj.shape[0]+1)
    yaw[..., 1:-1] = torch.atan2(
        traj_cat[..., 2:, 1] - traj_cat[..., :-2, 1],
        traj_cat[..., 2:, 0] - traj_cat[..., :-2, 0],
    )
    # Last step: backward difference only.
    yaw[..., -1] = torch.atan2(
        traj_cat[..., -1, 1] - traj_cat[..., -2, 1],
        traj_cat[..., -1, 0] - traj_cat[..., -2, 0],
    )
    return yaw[1:]

class PlanningMetric():
    def __init__(
        self,
        n_future=6,
        compute_on_step: bool = False,
    ):
        # Ego footprint: W is width, H is length in meters — presumably the
        # nuScenes ego vehicle dimensions; TODO confirm against dataset spec.
        self.W = 1.85
        self.H = 4.084

        self.n_future = n_future
        self.reset()

    def reset(self):
        # Per-timestep accumulators over all evaluated samples.
        self.obj_col = torch.zeros(self.n_future)
        self.obj_box_col = torch.zeros(self.n_future)
        self.obj_box_col_occluded = torch.zeros(self.n_future)
        self.obj_box_col_all = torch.zeros(self.n_future)
        self.L2 = torch.zeros(self.n_future)
        self.total = torch.tensor(0)
        self.total_occluded = torch.tensor(0)

    def evaluate_single_coll(self, traj, fut_boxes):
        # Builds an oriented ego box at each future step and tests it against
        # that step's obstacle boxes; returns a bool tensor per step.
        n_future = traj.shape[0]
        yaw = get_yaw(traj)
        ego_box = traj.new_zeros((n_future, 7))
        ego_box[:, :2] = traj
        ego_box[:, 3:6] = ego_box.new_tensor([self.H, self.W, 1.56])
        ego_box[:, 6] = yaw
        collision = torch.zeros(n_future, dtype=torch.bool)

        for t in range(n_future):
            # Clone so check_collision's in-place 0.5m offset doesn't leak back.
            ego_box_t = ego_box[t].clone()
            boxes = fut_boxes[t][0].clone()
            collision[t] = check_collision(ego_box_t, boxes)
        return collision

    def evaluate_coll(self, trajs, gt_trajs, fut_boxes):
        B, n_future, _ = trajs.shape
        # Flip x back to the obstacle frame (update() negated it for L2).
        trajs = trajs * torch.tensor([-1, 1], device=trajs.device)
        gt_trajs = gt_trajs * torch.tensor([-1, 1], device=gt_trajs.device)

        obj_coll_sum = torch.zeros(n_future, device=trajs.device)
        obj_box_coll_sum = torch.zeros(n_future,
 device=trajs.device)

        assert B == 1, 'only supprt bs=1'
        for i in range(B):
            gt_box_coll = self.evaluate_single_coll(gt_trajs[i], fut_boxes)
            box_coll = self.evaluate_single_coll(trajs[i], fut_boxes)
            # Only count collisions the GT trajectory itself would have avoided.
            box_coll = torch.logical_and(box_coll, torch.logical_not(gt_box_coll))

            obj_coll_sum += gt_box_coll.long()
            obj_box_coll_sum += box_coll.long()

        return obj_coll_sum, obj_box_coll_sum

    def compute_L2(self, trajs, gt_trajs, gt_trajs_mask):
        '''
        trajs: torch.Tensor (B, n_future, 3)
        gt_trajs: torch.Tensor (B, n_future, 3)
        '''
        return torch.sqrt((((trajs[:, :, :2] - gt_trajs[:, :, :2]) ** 2) * gt_trajs_mask).sum(dim=-1))

    def update(self, trajs, gt_trajs, gt_trajs_mask, fut_boxes, fut_boxes_occluded=None):
        assert trajs.shape == gt_trajs.shape
        # NOTE(review): the x-axis is negated in place here (mutates the caller's
        # tensors) and flipped again inside evaluate_coll; the two flips cancel for
        # collision checks and L2 is sign-invariant — confirm before refactoring.
        trajs[..., 0] = - trajs[..., 0]
        gt_trajs[..., 0] = - gt_trajs[..., 0]
        L2 = self.compute_L2(trajs, gt_trajs, gt_trajs_mask)
        obj_coll_sum, obj_box_coll_sum = self.evaluate_coll(trajs[:,:,:2], gt_trajs[:,:,:2], fut_boxes)

        self.obj_col += obj_coll_sum
        self.obj_box_col += obj_box_coll_sum
        self.L2 += L2.sum(dim=0)
        self.total += len(trajs)

        if fut_boxes_occluded is not None:
            # Separately accumulate collisions against occluded-only obstacles,
            # and against the union of visible + occluded obstacles.
            _, obj_box_coll_occ_sum = self.evaluate_coll(trajs[:,:,:2], gt_trajs[:,:,:2], fut_boxes_occluded)
            self.obj_box_col_occluded += obj_box_coll_occ_sum
            has_occluded = any(boxes[0].shape[0] > 0 for boxes in fut_boxes_occluded)
            self.total_occluded += int(has_occluded)

            merged_boxes = [
                [torch.cat([fut_boxes[t][0], fut_boxes_occluded[t][0]], dim=0)]
                for t in range(len(fut_boxes))
            ]
            _, obj_box_coll_all_sum = self.evaluate_coll(trajs[:,:,:2], gt_trajs[:,:,:2], merged_boxes)
            self.obj_box_col_all += obj_box_coll_all_sum

    def compute(self):
        # Avoid division by zero when no sample contained occluded boxes.
        occ_denom = self.total_occluded if self.total_occluded > 0 else torch.tensor(1)
        results = {
            'obj_col': self.obj_col / self.total,
            'obj_box_col': self.obj_box_col / self.total,
            'L2': self.L2 / self.total,
        }
        if self.total_occluded > 0:
results['occluded/obj_box_col'] = self.obj_box_col_occluded / occ_denom + results['all/obj_box_col'] = self.obj_box_col_all / self.total + return results + + +def planning_eval(results, eval_config, logger, with_occlusion=False): + dataset = build_dataset(eval_config) + dataloader = build_dataloader( + dataset, samples_per_gpu=1, workers_per_gpu=1, shuffle=False, dist=False) + planning_metrics = PlanningMetric() + occluded_samples = 0 + occluded_box_timesteps = 0 + for i, data in enumerate(tqdm(dataloader)): + sdc_planning = data['gt_ego_fut_trajs'].cumsum(dim=-2).unsqueeze(1) + sdc_planning_mask = data['gt_ego_fut_masks'].unsqueeze(-1).repeat(1, 1, 2).unsqueeze(1) + command = data['gt_ego_fut_cmd'].argmax(dim=-1).item() + fut_boxes = data['fut_boxes'] + fut_boxes_occluded = data.get('fut_boxes_occluded', None) if with_occlusion else None + if fut_boxes_occluded is not None: + sample_occ_count = sum(boxes[0].shape[0] for boxes in fut_boxes_occluded) + if sample_occ_count > 0: + occluded_samples += 1 + occluded_box_timesteps += sample_occ_count + if not sdc_planning_mask.all(): ## for incomplete gt, we do not count this sample + continue + res = results[i] + pred_sdc_traj = res['img_bbox']['final_planning'].unsqueeze(0) + planning_metrics.update(pred_sdc_traj[:, :6, :2], sdc_planning[0,:, :6, :2], sdc_planning_mask[0,:, :6, :2], fut_boxes, fut_boxes_occluded) + if with_occlusion: + print(f'[Occluded Planning] Samples with occluded future boxes: {occluded_samples} | Total occluded box-timestep instances: {occluded_box_timesteps}') + + planning_results = planning_metrics.compute() + planning_metrics.reset() + from prettytable import PrettyTable + planning_tab = PrettyTable() + metric_dict = {} + + planning_tab.field_names = [ + "metrics", "0.5s", "1.0s", "1.5s", "2.0s", "2.5s", "3.0s", "avg"] + for key in planning_results.keys(): + value = planning_results[key].tolist() + new_values = [] + for i in range(len(value)): + new_values.append(np.array(value[:i+1]).mean()) + 
value = new_values + avg = [value[1], value[3], value[5]] + avg = sum(avg) / len(avg) + value.append(avg) + metric_dict[key] = avg + row_value = [] + row_value.append(key) + for i in range(len(value)): + if 'col' in key: + row_value.append('%.3f' % float(value[i]*100) + '%') + else: + row_value.append('%.4f' % float(value[i])) + planning_tab.add_row(row_value) + + print_log('\n'+str(planning_tab), logger=logger) + return metric_dict diff --git a/projects/mmdet3d_plugin/datasets/map_utils/nuscmap_extractor.py b/projects/mmdet3d_plugin/datasets/map_utils/nuscmap_extractor.py new file mode 100644 index 0000000..a3792d4 --- /dev/null +++ b/projects/mmdet3d_plugin/datasets/map_utils/nuscmap_extractor.py @@ -0,0 +1,159 @@ +from shapely.geometry import LineString, box, Polygon +from shapely import ops, strtree + +import numpy as np +from nuscenes.map_expansion.map_api import NuScenesMap, NuScenesMapExplorer +from nuscenes.eval.common.utils import quaternion_yaw +from pyquaternion import Quaternion +from .utils import split_collections, get_drivable_area_contour, \ + get_ped_crossing_contour +from numpy.typing import NDArray +from typing import Dict, List, Tuple, Union + +class NuscMapExtractor(object): + """NuScenes map ground-truth extractor. 
    Args:
        data_root (str): path to nuScenes dataset
        roi_size (tuple or list): bev range
    """
    def __init__(self, data_root: str, roi_size: Union[List, Tuple]) -> None:
        self.roi_size = roi_size
        # The four map locations shipped with nuScenes.
        self.MAPS = ['boston-seaport', 'singapore-hollandvillage',
            'singapore-onenorth', 'singapore-queenstown']

        self.nusc_maps = {}
        self.map_explorer = {}
        for loc in self.MAPS:
            self.nusc_maps[loc] = NuScenesMap(
                dataroot=data_root, map_name=loc)
            self.map_explorer[loc] = NuScenesMapExplorer(self.nusc_maps[loc])

        # local patch in nuScenes format
        self.local_patch = box(-roi_size[0] / 2, -roi_size[1] / 2,
            roi_size[0] / 2, roi_size[1] / 2)

    def _union_ped(self, ped_geoms: List[Polygon]) -> List[Polygon]:
        ''' merge close ped crossings.

        Args:
            ped_geoms (list): list of Polygon

        Returns:
            union_ped_geoms (Dict): merged ped crossings
        '''

        def get_rec_direction(geom):
            # Direction and length of the longest edge of the geometry's
            # minimum rotated bounding rectangle.
            rect = geom.minimum_rotated_rectangle
            rect_v_p = np.array(rect.exterior.coords)[:3]
            rect_v = rect_v_p[1:]-rect_v_p[:-1]
            v_len = np.linalg.norm(rect_v, axis=-1)
            longest_v_i = v_len.argmax()

            return rect_v[longest_v_i], v_len[longest_v_i]

        # NOTE(review): shapely 2.x STRtree.query returns integer indices,
        # not geometry objects -- the id()-based lookup below assumes
        # shapely < 2.0.  Confirm the pinned shapely version.
        tree = strtree.STRtree(ped_geoms)
        index_by_id = dict((id(pt), i) for i, pt in enumerate(ped_geoms))

        final_pgeom = []
        # Indices not yet merged into any output polygon.
        remain_idx = [i for i in range(len(ped_geoms))]
        for i, pgeom in enumerate(ped_geoms):

            if i not in remain_idx:
                continue
            # update
            remain_idx.pop(remain_idx.index(i))
            pgeom_v, pgeom_v_norm = get_rec_direction(pgeom)
            final_pgeom.append(pgeom)

            # Merge spatially-close neighbours whose long axis is parallel.
            for o in tree.query(pgeom):
                o_idx = index_by_id[id(o)]
                if o_idx not in remain_idx:
                    continue

                o_v, o_v_norm = get_rec_direction(o)
                # Cosine of the angle between the two long axes.
                cos = pgeom_v.dot(o_v)/(pgeom_v_norm*o_v_norm)
                if 1 - np.abs(cos) < 0.01:  # theta < 8 degrees.
                    final_pgeom[-1] =\
                        final_pgeom[-1].union(o)
                    # update
                    remain_idx.pop(remain_idx.index(o_idx))

        # Split any resulting multi-geometries back into simple polygons.
        results = []
        for p in final_pgeom:
            results.extend(split_collections(p))
        return results

    def get_map_geom(self,
                     location: str,
                     translation: Union[List, NDArray],
                     rotation: Union[List, NDArray]) -> Dict[str, List[Union[LineString, Polygon]]]:
        ''' Extract geometries given `location` and self pose, self may be lidar or ego.

        Args:
            location (str): city name
            translation (array): self2global translation, shape (3,)
            rotation (array): self2global quaternion, shape (4, )

        Returns:
            geometries (Dict): extracted geometries by category.
        '''

        # (center_x, center_y, len_y, len_x) in nuscenes format
        patch_box = (translation[0], translation[1],
                self.roi_size[1], self.roi_size[0])
        rotation = Quaternion(rotation)
        # nuScenes map API expects the patch angle in degrees.
        yaw = quaternion_yaw(rotation) / np.pi * 180

        # get dividers
        lane_dividers = self.map_explorer[location]._get_layer_line(
                    patch_box, yaw, 'lane_divider')

        road_dividers = self.map_explorer[location]._get_layer_line(
                    patch_box, yaw, 'road_divider')

        all_dividers = []
        for line in lane_dividers + road_dividers:
            all_dividers += split_collections(line)

        # get ped crossings
        ped_crossings = []
        ped = self.map_explorer[location]._get_layer_polygon(
                    patch_box, yaw, 'ped_crossing')

        for p in ped:
            ped_crossings += split_collections(p)
        # some ped crossings are split into several small parts
        # we need to merge them
        ped_crossings = self._union_ped(ped_crossings)

        ped_crossing_lines = []
        for p in ped_crossings:
            # extract exteriors to get a closed polyline
            line = get_ped_crossing_contour(p, self.local_patch)
            if line is not None:
                ped_crossing_lines.append(line)

        # get boundaries
        # we take the union of road segments and lanes as drivable areas
        # we don't take drivable area layer in nuScenes since its definition may be ambiguous
        road_segments =
self.map_explorer[location]._get_layer_polygon( + patch_box, yaw, 'road_segment') + lanes = self.map_explorer[location]._get_layer_polygon( + patch_box, yaw, 'lane') + union_roads = ops.unary_union(road_segments) + union_lanes = ops.unary_union(lanes) + drivable_areas = ops.unary_union([union_roads, union_lanes]) + + drivable_areas = split_collections(drivable_areas) + + # boundaries are defined as the contour of drivable areas + boundaries = get_drivable_area_contour(drivable_areas, self.roi_size) + + return dict( + divider=all_dividers, # List[LineString] + ped_crossing=ped_crossing_lines, # List[LineString] + boundary=boundaries, # List[LineString] + drivable_area=drivable_areas, # List[Polygon], + ) + diff --git a/projects/mmdet3d_plugin/datasets/map_utils/utils.py b/projects/mmdet3d_plugin/datasets/map_utils/utils.py new file mode 100644 index 0000000..7dac57a --- /dev/null +++ b/projects/mmdet3d_plugin/datasets/map_utils/utils.py @@ -0,0 +1,119 @@ +from shapely.geometry import LineString, box, Polygon, LinearRing +from shapely.geometry.base import BaseGeometry +from shapely import ops +import numpy as np +from scipy.spatial import distance +from typing import List, Optional, Tuple +from numpy.typing import NDArray + +def split_collections(geom: BaseGeometry) -> List[Optional[BaseGeometry]]: + ''' Split Multi-geoms to list and check is valid or is empty. + + Args: + geom (BaseGeometry): geoms to be split or validate. + + Returns: + geometries (List): list of geometries. 
+ ''' + assert geom.geom_type in ['MultiLineString', 'LineString', 'MultiPolygon', + 'Polygon', 'GeometryCollection'], f"got geom type {geom.geom_type}" + if 'Multi' in geom.geom_type: + outs = [] + for g in geom.geoms: + if g.is_valid and not g.is_empty: + outs.append(g) + return outs + else: + if geom.is_valid and not geom.is_empty: + return [geom,] + else: + return [] + +def get_drivable_area_contour(drivable_areas: List[Polygon], + roi_size: Tuple) -> List[LineString]: + ''' Extract drivable area contours to get list of boundaries. + + Args: + drivable_areas (list): list of drivable areas. + roi_size (tuple): bev range size + + Returns: + boundaries (List): list of boundaries. + ''' + max_x = roi_size[0] / 2 + max_y = roi_size[1] / 2 + + # a bit smaller than roi to avoid unexpected boundaries on edges + local_patch = box(-max_x + 0.2, -max_y + 0.2, max_x - 0.2, max_y - 0.2) + + exteriors = [] + interiors = [] + + for poly in drivable_areas: + exteriors.append(poly.exterior) + for inter in poly.interiors: + interiors.append(inter) + + results = [] + for ext in exteriors: + # NOTE: we make sure all exteriors are clock-wise + # such that each boundary's right-hand-side is drivable area + # and left-hand-side is walk way + + if ext.is_ccw: + ext = LinearRing(list(ext.coords)[::-1]) + lines = ext.intersection(local_patch) + if lines.geom_type == 'MultiLineString': + lines = ops.linemerge(lines) + assert lines.geom_type in ['MultiLineString', 'LineString'] + + results.extend(split_collections(lines)) + + for inter in interiors: + # NOTE: we make sure all interiors are counter-clock-wise + if not inter.is_ccw: + inter = LinearRing(list(inter.coords)[::-1]) + lines = inter.intersection(local_patch) + if lines.geom_type == 'MultiLineString': + lines = ops.linemerge(lines) + assert lines.geom_type in ['MultiLineString', 'LineString'] + + results.extend(split_collections(lines)) + + return results + +def get_ped_crossing_contour(polygon: Polygon, + local_patch: box) -> 
Optional[LineString]: + ''' Extract ped crossing contours to get a closed polyline. + Different from `get_drivable_area_contour`, this function ensures a closed polyline. + + Args: + polygon (Polygon): ped crossing polygon to be extracted. + local_patch (tuple): local patch params + + Returns: + line (LineString): a closed line + ''' + + ext = polygon.exterior + if not ext.is_ccw: + ext = LinearRing(list(ext.coords)[::-1]) + lines = ext.intersection(local_patch) + if lines.type != 'LineString': + # remove points in intersection results + lines = [l for l in lines.geoms if l.geom_type != 'Point'] + lines = ops.linemerge(lines) + + # same instance but not connected. + if lines.type != 'LineString': + ls = [] + for l in lines.geoms: + ls.append(np.array(l.coords)) + + lines = np.concatenate(ls, axis=0) + lines = LineString(lines) + if not lines.is_empty: + return lines + + return None + diff --git a/projects/mmdet3d_plugin/datasets/nuscenes_3d_dataset.py b/projects/mmdet3d_plugin/datasets/nuscenes_3d_dataset.py new file mode 100644 index 0000000..1dbdf9c --- /dev/null +++ b/projects/mmdet3d_plugin/datasets/nuscenes_3d_dataset.py @@ -0,0 +1,1529 @@ +import random +import math +import os +from os import path as osp +import cv2 +import tempfile +import copy +import prettytable + +import numpy as np +import torch +from torch.utils.data import Dataset +import pyquaternion +from shapely.geometry import LineString +from nuscenes.utils.data_classes import Box as NuScenesBox +from nuscenes.eval.detection.config import config_factory as det_configs +from nuscenes.eval.common.config import config_factory as track_configs + +import mmcv +from mmcv.utils import print_log +from mmdet.datasets import DATASETS +from mmdet.datasets.pipelines import Compose +from .utils import ( + draw_lidar_bbox3d_on_img, + draw_lidar_bbox3d_on_bev, +) + + +@DATASETS.register_module() +class NuScenes3DDataset(Dataset): + DefaultAttribute = { + "car": "vehicle.parked", + "pedestrian": 
"pedestrian.moving", + "trailer": "vehicle.parked", + "truck": "vehicle.parked", + "bus": "vehicle.moving", + "motorcycle": "cycle.without_rider", + "construction_vehicle": "vehicle.parked", + "bicycle": "cycle.without_rider", + "barrier": "", + "traffic_cone": "", + } + ErrNameMapping = { + "trans_err": "mATE", + "scale_err": "mASE", + "orient_err": "mAOE", + "vel_err": "mAVE", + "attr_err": "mAAE", + } + CLASSES = ( + "car", + "truck", + "trailer", + "bus", + "construction_vehicle", + "bicycle", + "motorcycle", + "pedestrian", + "traffic_cone", + "barrier", + ) + MAP_CLASSES = ( + 'ped_crossing', + 'divider', + 'boundary', + ) + ID_COLOR_MAP = [ + (59, 59, 238), + (0, 255, 0), + (0, 0, 255), + (255, 255, 0), + (0, 255, 255), + (255, 0, 255), + (255, 255, 255), + (0, 127, 255), + (71, 130, 255), + (127, 127, 0), + ] + + def __init__( + self, + ann_file, + pipeline=None, + data_root=None, + classes=None, + map_classes=None, + load_interval=1, + with_velocity=True, + with_visibility=True, + modality=None, + test_mode=False, + det3d_eval_version="detection_cvpr_2019", + track3d_eval_version="tracking_nips_2019", + version="v1.0-trainval", + use_valid_flag=False, + use_gt_mask=True, + vis_score_threshold=0.25, + data_aug_conf=None, + sequences_split_num=1, + with_seq_flag=False, + keep_consistent_seq_aug=True, + work_dir=None, + eval_config=None, + ): + self.version = version + self.load_interval = load_interval + self.use_valid_flag = use_valid_flag + self.use_gt_mask = use_gt_mask + super().__init__() + self.data_root = data_root + self.ann_file = ann_file + self.test_mode = test_mode + self.modality = modality + self.box_mode_3d = 0 + + if classes is not None: + self.CLASSES = classes + if map_classes is not None: + self.MAP_CLASSES = map_classes + self.cat2id = {name: i for i, name in enumerate(self.CLASSES)} + self.data_infos = self.load_annotations(self.ann_file) + + if pipeline is not None: + self.pipeline = Compose(pipeline) + + self.with_velocity = 
with_velocity + self.with_visibility = with_visibility + self.det3d_eval_version = det3d_eval_version + self.det3d_eval_configs = det_configs(self.det3d_eval_version) + self.det3d_eval_configs.class_names = list(self.det3d_eval_configs.class_range.keys()) + self.track3d_eval_version = track3d_eval_version + self.track3d_eval_configs = track_configs(self.track3d_eval_version) + self.track3d_eval_configs.class_names = list(self.track3d_eval_configs.class_range.keys()) + if self.modality is None: + self.modality = dict( + use_camera=False, + use_lidar=True, + use_radar=False, + use_map=False, + use_external=False, + ) + self.vis_score_threshold = vis_score_threshold + + self.data_aug_conf = data_aug_conf + self.sequences_split_num = sequences_split_num + self.keep_consistent_seq_aug = keep_consistent_seq_aug + if with_seq_flag: + self._set_sequence_group_flag() + + self.work_dir = work_dir + self.eval_config = eval_config + + def __len__(self): + return len(self.data_infos) + + def _set_sequence_group_flag(self): + """ + Set each sequence to be a different group + """ + if self.sequences_split_num == -1: + self.flag = np.arange(len(self.data_infos)) + return + + res = [] + + curr_sequence = 0 + for idx in range(len(self.data_infos)): + if idx != 0 and len(self.data_infos[idx]["sweeps"]) == 0: + # Not first frame and # of sweeps is 0 -> new sequence + curr_sequence += 1 + res.append(curr_sequence) + + self.flag = np.array(res, dtype=np.int64) + + if self.sequences_split_num != 1: + if self.sequences_split_num == "all": + self.flag = np.array( + range(len(self.data_infos)), dtype=np.int64 + ) + else: + bin_counts = np.bincount(self.flag) + new_flags = [] + curr_new_flag = 0 + for curr_flag in range(len(bin_counts)): + curr_sequence_length = np.array( + list( + range( + 0, + bin_counts[curr_flag], + math.ceil( + bin_counts[curr_flag] + / self.sequences_split_num + ), + ) + ) + + [bin_counts[curr_flag]] + ) + + for sub_seq_idx in ( + curr_sequence_length[1:] - 
curr_sequence_length[:-1] + ): + for _ in range(sub_seq_idx): + new_flags.append(curr_new_flag) + curr_new_flag += 1 + + assert len(new_flags) == len(self.flag) + assert ( + len(np.bincount(new_flags)) + == len(np.bincount(self.flag)) * self.sequences_split_num + ) + self.flag = np.array(new_flags, dtype=np.int64) + + def get_augmentation(self): + if self.data_aug_conf is None: + return None + H, W = self.data_aug_conf["H"], self.data_aug_conf["W"] + fH, fW = self.data_aug_conf["final_dim"] + if not self.test_mode: + resize = np.random.uniform(*self.data_aug_conf["resize_lim"]) + resize_dims = (int(W * resize), int(H * resize)) + newW, newH = resize_dims + crop_h = ( + int( + (1 - np.random.uniform(*self.data_aug_conf["bot_pct_lim"])) + * newH + ) + - fH + ) + crop_w = int(np.random.uniform(0, max(0, newW - fW))) + crop = (crop_w, crop_h, crop_w + fW, crop_h + fH) + flip = False + if self.data_aug_conf["rand_flip"] and np.random.choice([0, 1]): + flip = True + rotate = np.random.uniform(*self.data_aug_conf["rot_lim"]) + rotate_3d = np.random.uniform(*self.data_aug_conf["rot3d_range"]) + else: + resize = max(fH / H, fW / W) + resize_dims = (int(W * resize), int(H * resize)) + newW, newH = resize_dims + crop_h = ( + int((1 - np.mean(self.data_aug_conf["bot_pct_lim"])) * newH) + - fH + ) + crop_w = int(max(0, newW - fW) / 2) + crop = (crop_w, crop_h, crop_w + fW, crop_h + fH) + flip = False + rotate = 0 + rotate_3d = 0 + aug_config = { + "resize": resize, + "resize_dims": resize_dims, + "crop": crop, + "flip": flip, + "rotate": rotate, + "rotate_3d": rotate_3d, + } + return aug_config + + def __getitem__(self, idx): + if isinstance(idx, dict): + aug_config = idx["aug_config"] + idx = idx["idx"] + else: + aug_config = self.get_augmentation() + data = self.get_data_info(idx) + data["aug_config"] = aug_config + data = self.pipeline(data) + return data + + def get_cat_ids(self, idx): + info = self.data_infos[idx] + if self.use_valid_flag and self.use_gt_mask: + mask = 
info["valid_flag"] + gt_names = set(info["gt_names"][mask]) + else: + gt_names = set(info["gt_names"]) + + cat_ids = [] + for name in gt_names: + if name in self.CLASSES: + cat_ids.append(self.cat2id[name]) + return cat_ids + + def load_annotations(self, ann_file): + data = mmcv.load(ann_file, file_format="pkl") + data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"])) + data_infos = data_infos[:: self.load_interval] + self.metadata = data["metadata"] + self.version = self.metadata["version"] + print(self.metadata) + return data_infos + + def anno2geom(self, annos): + map_geoms = {} + for label, anno_list in annos.items(): + map_geoms[label] = [] + for anno in anno_list: + geom = LineString(anno) + map_geoms[label].append(geom) + return map_geoms + + def get_data_info(self, index): + info = self.data_infos[index] + input_dict = dict( + token=info["token"], + map_location=info["map_location"], + pts_filename=info["lidar_path"], + sweeps=info["sweeps"], + timestamp=info["timestamp"] / 1e6, + lidar2ego_translation=info["lidar2ego_translation"], + lidar2ego_rotation=info["lidar2ego_rotation"], + ego2global_translation=info["ego2global_translation"], + ego2global_rotation=info["ego2global_rotation"], + ego_status=info['ego_status'].astype(np.float32), + map_infos=info["map_annos"], + ) + lidar2ego = np.eye(4) + lidar2ego[:3, :3] = pyquaternion.Quaternion( + info["lidar2ego_rotation"] + ).rotation_matrix + lidar2ego[:3, 3] = np.array(info["lidar2ego_translation"]) + ego2global = np.eye(4) + ego2global[:3, :3] = pyquaternion.Quaternion( + info["ego2global_rotation"] + ).rotation_matrix + ego2global[:3, 3] = np.array(info["ego2global_translation"]) + input_dict["lidar2global"] = ego2global @ lidar2ego + + map_geoms = self.anno2geom(info["map_annos"]) + input_dict["map_geoms"] = map_geoms + + if self.modality["use_camera"]: + image_paths = [] + lidar2img_rts = [] + lidar2cam_rts = [] + cam_intrinsic = [] + for cam_type, cam_info in info["cams"].items(): + 
image_paths.append(cam_info["data_path"]) + # obtain lidar to image transformation matrix + lidar2cam_r = np.linalg.inv(cam_info["sensor2lidar_rotation"]) + lidar2cam_t = ( + cam_info["sensor2lidar_translation"] @ lidar2cam_r.T + ) + lidar2cam_rt = np.eye(4) + lidar2cam_rt[:3, :3] = lidar2cam_r.T + lidar2cam_rt[3, :3] = -lidar2cam_t + intrinsic = copy.deepcopy(cam_info["cam_intrinsic"]) + cam_intrinsic.append(intrinsic) + viewpad = np.eye(4) + viewpad[: intrinsic.shape[0], : intrinsic.shape[1]] = intrinsic + lidar2img_rt = viewpad @ lidar2cam_rt.T + lidar2img_rts.append(lidar2img_rt) + lidar2cam_rts.append(lidar2cam_rt) + + input_dict.update( + dict( + img_filename=image_paths, + lidar2img=lidar2img_rts, + lidar2cam=lidar2cam_rts, + cam_intrinsic=cam_intrinsic, + ) + ) + + annos = self.get_ann_info(index) + input_dict.update(annos) + return input_dict + + def get_ann_info(self, index): + info = self.data_infos[index] + if self.use_gt_mask: + if self.use_valid_flag: + mask = info["valid_flag"] + else: + mask = info["num_lidar_pts"] > 0 + else: + mask = np.ones(len(info["gt_boxes"]), dtype=bool) + gt_bboxes_3d = info["gt_boxes"][mask] + gt_names_3d = info["gt_names"][mask] + gt_labels_3d = [] + for cat in gt_names_3d: + if cat in self.CLASSES: + gt_labels_3d.append(self.CLASSES.index(cat)) + else: + gt_labels_3d.append(-1) + gt_labels_3d = np.array(gt_labels_3d) + + if self.with_velocity: + gt_velocity = info["gt_velocity"][mask] + nan_mask = np.isnan(gt_velocity[:, 0]) + gt_velocity[nan_mask] = [0.0, 0.0] + gt_bboxes_3d = np.concatenate([gt_bboxes_3d, gt_velocity], axis=-1) + + if self.with_visibility: + gt_visibility = (info["num_lidar_pts"][mask] > 0).astype(np.float32) + + anns_results = dict( + gt_bboxes_3d=gt_bboxes_3d, + gt_labels_3d=gt_labels_3d, + gt_names=gt_names_3d, + gt_visibility=gt_visibility, + ) + if "instance_inds" in info: + instance_inds = np.array(info["instance_inds"], dtype=np.int)[mask] + anns_results["instance_inds"] = instance_inds + + if 
'gt_agent_fut_trajs' in info: + anns_results['gt_agent_fut_trajs'] = info['gt_agent_fut_trajs'][mask] + anns_results['gt_agent_fut_masks'] = info['gt_agent_fut_masks'][mask] + + if 'gt_ego_fut_trajs' in info: + anns_results['gt_ego_fut_trajs'] = info['gt_ego_fut_trajs'] + anns_results['gt_ego_fut_masks'] = info['gt_ego_fut_masks'] + anns_results['gt_ego_fut_cmd'] = info['gt_ego_fut_cmd'] + + ## get future box for planning eval + fut_ts = int(info['gt_ego_fut_masks'].sum()) + fut_boxes = [] + fut_boxes_occluded = [] + cur_scene_token = info["scene_token"] + cur_T_global = get_T_global(info) + for i in range(1, fut_ts + 1): + fut_info = self.data_infos[index + i] + fut_scene_token = fut_info["scene_token"] + if cur_scene_token != fut_scene_token: + break + if self.use_gt_mask: + if self.use_valid_flag: + mask = fut_info["valid_flag"] + else: + mask = fut_info["num_lidar_pts"] > 0 + else: + mask = np.ones(len(fut_info["gt_boxes"]), dtype=bool) + if self.use_valid_flag: + occluded_mask = ~fut_info["valid_flag"] + else: + occluded_mask = ~(fut_info["num_lidar_pts"] > 0) + + fut_gt_bboxes_3d = fut_info["gt_boxes"][mask] + fut_gt_bboxes_occluded = fut_info["gt_boxes"][occluded_mask] + + fut_T_global = get_T_global(fut_info) + T_fut2cur = np.linalg.inv(cur_T_global) @ fut_T_global + + for bboxes in (fut_gt_bboxes_3d, fut_gt_bboxes_occluded): + if len(bboxes): + center = bboxes[:, :3] @ T_fut2cur[:3, :3].T + T_fut2cur[:3, 3] + yaw = np.stack([np.cos(bboxes[:, 6]), np.sin(bboxes[:, 6])], axis=-1) + yaw = yaw @ T_fut2cur[:2, :2].T + bboxes[:, :3] = center + bboxes[:, 6] = np.arctan2(yaw[..., 1], yaw[..., 0]) + + fut_boxes.append(fut_gt_bboxes_3d) + fut_boxes_occluded.append(fut_gt_bboxes_occluded) + + anns_results['fut_boxes'] = fut_boxes + anns_results['fut_boxes_occluded'] = fut_boxes_occluded + + return anns_results + + def _format_bbox(self, results, jsonfile_prefix=None, tracking=False): + nusc_annos = {} + mapped_class_names = self.CLASSES + + print("Start to convert 
detection format...") + for sample_id, det in enumerate(mmcv.track_iter_progress(results)): + annos = [] + boxes = output_to_nusc_box( + det, threshold=self.tracking_threshold if tracking else None + ) + sample_token = self.data_infos[sample_id]["token"] + boxes = lidar_nusc_box_to_global( + self.data_infos[sample_id], + boxes, + mapped_class_names, + self.det3d_eval_configs, + self.det3d_eval_version, + ) + for i, box in enumerate(boxes): + name = mapped_class_names[box.label] + if tracking and name in [ + "barrier", + "traffic_cone", + "construction_vehicle", + ]: + continue + if np.sqrt(box.velocity[0] ** 2 + box.velocity[1] ** 2) > 0.2: + if name in [ + "car", + "construction_vehicle", + "bus", + "truck", + "trailer", + ]: + attr = "vehicle.moving" + elif name in ["bicycle", "motorcycle"]: + attr = "cycle.with_rider" + else: + attr = NuScenes3DDataset.DefaultAttribute[name] + else: + if name in ["pedestrian"]: + attr = "pedestrian.standing" + elif name in ["bus"]: + attr = "vehicle.stopped" + else: + attr = NuScenes3DDataset.DefaultAttribute[name] + + nusc_anno = dict( + sample_token=sample_token, + translation=box.center.tolist(), + size=box.wlh.tolist(), + rotation=box.orientation.elements.tolist(), + velocity=box.velocity[:2].tolist(), + ) + if not tracking: + nusc_anno.update( + dict( + detection_name=name, + detection_score=box.score, + attribute_name=attr, + ) + ) + else: + nusc_anno.update( + dict( + tracking_name=name, + tracking_score=box.score, + tracking_id=str(box.token), + ) + ) + + annos.append(nusc_anno) + nusc_annos[sample_token] = annos + nusc_submissions = { + "meta": self.modality, + "results": nusc_annos, + } + + mmcv.mkdir_or_exist(jsonfile_prefix) + filename = "results_nusc_tracking.json" if tracking else "results_nusc.json" + res_path = osp.join(jsonfile_prefix, filename) + print("Results writes to", res_path) + mmcv.dump(nusc_submissions, res_path) + return res_path + + def _evaluate_single( + self, result_path, logger=None, 
result_name="img_bbox", tracking=False + ): + from nuscenes import NuScenes + + output_dir = osp.join(*osp.split(result_path)[:-1]) + nusc = NuScenes( + version=self.version, dataroot=self.data_root, verbose=False + ) + eval_set_map = { + "v1.0-mini": "mini_val", + "v1.0-trainval": "val", + } + if not tracking: + from nuscenes.eval.detection.evaluate import NuScenesEval + + nusc_eval = NuScenesEval( + nusc, + config=self.det3d_eval_configs, + result_path=result_path, + eval_set=eval_set_map[self.version], + output_dir=output_dir, + verbose=True, + ) + nusc_eval.main(render_curves=False) + + # record metrics + metrics = mmcv.load(osp.join(output_dir, "metrics_summary.json")) + detail = dict() + metric_prefix = f"{result_name}_NuScenes" + for name in self.CLASSES: + for k, v in metrics["label_aps"][name].items(): + val = float("{:.4f}".format(v)) + detail[ + "{}/{}_AP_dist_{}".format(metric_prefix, name, k) + ] = val + for k, v in metrics["label_tp_errors"][name].items(): + val = float("{:.4f}".format(v)) + detail["{}/{}_{}".format(metric_prefix, name, k)] = val + for k, v in metrics["tp_errors"].items(): + val = float("{:.4f}".format(v)) + detail[ + "{}/{}".format(metric_prefix, self.ErrNameMapping[k]) + ] = val + + detail["{}/NDS".format(metric_prefix)] = metrics["nd_score"] + detail["{}/mAP".format(metric_prefix)] = metrics["mean_ap"] + else: + from nuscenes.eval.tracking.evaluate import TrackingEval + + nusc_eval = TrackingEval( + config=self.track3d_eval_configs, + result_path=result_path, + eval_set=eval_set_map[self.version], + output_dir=output_dir, + verbose=True, + nusc_version=self.version, + nusc_dataroot=self.data_root, + ) + metrics = nusc_eval.main() + + # record metrics + metrics = mmcv.load(osp.join(output_dir, "metrics_summary.json")) + print(metrics) + detail = dict() + metric_prefix = f"{result_name}_NuScenes" + keys = [ + "amota", + "amotp", + "recall", + "motar", + "gt", + "mota", + "motp", + "mt", + "ml", + "faf", + "tp", + "fp", + "fn", + 
"ids", + "frag", + "tid", + "lgd", + ] + for key in keys: + detail["{}/{}".format(metric_prefix, key)] = metrics[key] + + return detail + + def format_results(self, results, jsonfile_prefix=None, tracking=False): + assert isinstance(results, list), "results must be a list" + + if jsonfile_prefix is None: + tmp_dir = tempfile.TemporaryDirectory() + jsonfile_prefix = osp.join(tmp_dir.name, "results") + else: + tmp_dir = None + + if not ("pts_bbox" in results[0] or "img_bbox" in results[0]): + result_files = self._format_bbox( + results, jsonfile_prefix, tracking=tracking + ) + else: + result_files = dict() + for name in results[0]: + print(f"\nFormating bboxes of {name}") + results_ = [out[name] for out in results] + tmp_file_ = jsonfile_prefix + result_files.update( + { + name: self._format_bbox( + results_, tmp_file_, tracking=tracking + ) + } + ) + return result_files, tmp_dir + + def format_map_results(self, results, prefix=None): + submissions = {'results': {},} + + for j, pred in enumerate(results): + ''' + For each case, the result should be formatted as Dict{'vectors': [], 'scores': [], 'labels': []} + 'vectors': List of vector, each vector is a array([[x1, y1], [x2, y2] ...]), + contain all vectors predicted in this sample. + 'scores: List of score(float), + contain scores of all instances in this sample. + 'labels': List of label(int), + contain labels of all instances in this sample. 
+ ''' + if pred is None: # empty prediction + continue + pred = pred['img_bbox'] + + single_case = {'vectors': [], 'scores': [], 'labels': []} + token = self.data_infos[j]['token'] + for i in range(len(pred['scores'])): + score = pred['scores'][i] + label = pred['labels'][i] + vector = pred['vectors'][i] + + # A line should have >=2 points + if len(vector) < 2: + continue + + single_case['vectors'].append(vector) + single_case['scores'].append(score) + single_case['labels'].append(label) + + submissions['results'][token] = single_case + + out_path = osp.join(prefix, 'submission_vector.json') + print(f'saving submissions results to {out_path}') + os.makedirs(os.path.dirname(out_path), exist_ok=True) + mmcv.dump(submissions, out_path) + return out_path + + def format_motion_results(self, results, jsonfile_prefix=None, tracking=False, thresh=None): + nusc_annos = {} + mapped_class_names = self.CLASSES + + print("Start to convert detection format...") + for sample_id, det in enumerate(mmcv.track_iter_progress(results)): + annos = [] + boxes = output_to_nusc_box( + det['img_bbox'], threshold=None + ) + sample_token = self.data_infos[sample_id]["token"] + boxes = lidar_nusc_box_to_global( + self.data_infos[sample_id], + boxes, + mapped_class_names, + self.det3d_eval_configs, + self.det3d_eval_version, + filter_with_cls_range=False, + ) + for i, box in enumerate(boxes): + if thresh is not None and box.score < thresh: + continue + name = mapped_class_names[box.label] + if tracking and name in [ + "barrier", + "traffic_cone", + "construction_vehicle", + ]: + continue + if np.sqrt(box.velocity[0] ** 2 + box.velocity[1] ** 2) > 0.2: + if name in [ + "car", + "construction_vehicle", + "bus", + "truck", + "trailer", + ]: + attr = "vehicle.moving" + elif name in ["bicycle", "motorcycle"]: + attr = "cycle.with_rider" + else: + attr = NuScenes3DDataset.DefaultAttribute[name] + else: + if name in ["pedestrian"]: + attr = "pedestrian.standing" + elif name in ["bus"]: + attr = 
"vehicle.stopped" + else: + attr = NuScenes3DDataset.DefaultAttribute[name] + + nusc_anno = dict( + sample_token=sample_token, + translation=box.center.tolist(), + size=box.wlh.tolist(), + rotation=box.orientation.elements.tolist(), + velocity=box.velocity[:2].tolist(), + ) + if not tracking: + nusc_anno.update( + dict( + detection_name=name, + detection_score=box.score, + attribute_name=attr, + ) + ) + else: + nusc_anno.update( + dict( + tracking_name=name, + tracking_score=box.score, + tracking_id=str(box.token), + ) + ) + nusc_anno.update( + dict( + trajs=det['img_bbox']['trajs_3d'][i].numpy(), + ) + ) + if 'trajs_score' in det['img_bbox']: + nusc_anno['trajs_score'] = det['img_bbox']['trajs_score'][i].numpy() + annos.append(nusc_anno) + nusc_annos[sample_token] = annos + nusc_submissions = { + "meta": self.modality, + "results": nusc_annos, + } + + return nusc_submissions + + def _evaluate_single_motion(self, + results, + result_path, + logger=None, + metric='bbox', + result_name='pts_bbox'): + """Evaluation for a single model in nuScenes protocol. + + Args: + result_path (str): Path of the result file. + logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + metric (str): Metric name used for evaluation. Default: 'bbox'. + result_name (str): Result name in the metric prefix. + Default: 'pts_bbox'. + + Returns: + dict: Dictionary of evaluation details. 
+ """ + from nuscenes import NuScenes + from .evaluation.motion.motion_eval_uniad import NuScenesEval as NuScenesEvalMotion + + output_dir = result_path + nusc = NuScenes( + version=self.version, dataroot=self.data_root, verbose=False) + eval_set_map = { + 'v1.0-mini': 'mini_val', + 'v1.0-trainval': 'val', + } + nusc_eval = NuScenesEvalMotion( + nusc, + config=copy.deepcopy(self.det3d_eval_configs), + result_path=results, + eval_set=eval_set_map[self.version], + output_dir=output_dir, + verbose=False, + seconds=6) + metrics = nusc_eval.main(render_curves=False) + + MOTION_METRICS = ['EPA', 'min_ade_err', 'min_fde_err', 'miss_rate_err'] + class_names = ['car', 'pedestrian'] + + table = prettytable.PrettyTable() + table.field_names = ["class names"] + MOTION_METRICS + for class_name in class_names: + row_data = [class_name] + for m in MOTION_METRICS: + row_data.append('%.4f' % metrics[f'{class_name}_{m}']) + table.add_row(row_data) + print_log('\n'+str(table), logger=logger) + return metrics + + def _evaluate_single_det_occluded(self, result_path, logger=None, result_name='img_bbox'): + """Evaluate detection on occluded objects only (num_lidar_pts == 0).""" + from nuscenes import NuScenes + from .evaluation.det.occluded_det_eval import OccludedDetectionEval + + output_dir = osp.join(osp.dirname(result_path), 'occluded_det') + nusc = NuScenes(version=self.version, dataroot=self.data_root, verbose=False) + eval_set_map = { + 'v1.0-mini': 'mini_val', + 'v1.0-trainval': 'val', + } + nusc_eval = OccludedDetectionEval( + nusc, + config=self.det3d_eval_configs, + result_path=result_path, + eval_set=eval_set_map[self.version], + output_dir=output_dir, + verbose=False, + ) + nusc_eval.main(render_curves=False) + + metrics = mmcv.load(osp.join(output_dir, 'metrics_summary.json')) + detail = {} + for name in self.CLASSES: + for k, v in metrics['label_aps'].get(name, {}).items(): + detail[f'occluded/{name}_AP_dist_{k}'] = float('{:.4f}'.format(v)) + for k, v in 
metrics['label_tp_errors'].get(name, {}).items(): + detail[f'occluded/{name}_{k}'] = float('{:.4f}'.format(v)) + for k, v in metrics['tp_errors'].items(): + detail[f'occluded/{self.ErrNameMapping[k]}'] = float('{:.4f}'.format(v)) + detail['occluded/NDS'] = metrics['nd_score'] + detail['occluded/mAP'] = metrics['mean_ap'] + return detail + + def _evaluate_single_motion_occluded(self, + results, + result_path, + logger=None): + """Evaluate motion prediction restricted to occluded objects (num_lidar_pts == 0).""" + from nuscenes import NuScenes + from .evaluation.motion.motion_eval_uniad import OccludedMotionEval + + output_dir = osp.join(result_path, 'occluded_motion') + nusc = NuScenes( + version=self.version, dataroot=self.data_root, verbose=False) + eval_set_map = { + 'v1.0-mini': 'mini_val', + 'v1.0-trainval': 'val', + } + nusc_eval = OccludedMotionEval( + nusc, + config=copy.deepcopy(self.det3d_eval_configs), + result_path=results, + eval_set=eval_set_map[self.version], + output_dir=output_dir, + verbose=False, + seconds=6) + metrics = nusc_eval.main(render_curves=False) + + MOTION_METRICS = ['EPA', 'min_ade_err', 'min_fde_err', 'miss_rate_err'] + class_names = ['car', 'pedestrian'] + + table = prettytable.PrettyTable() + table.field_names = ["class names (occluded)"] + MOTION_METRICS + for class_name in class_names: + row_data = [class_name] + for m in MOTION_METRICS: + row_data.append('%.4f' % metrics[f'{class_name}_{m}']) + table.add_row(row_data) + print_log('\n[Occluded Objects]\n' + str(table), logger=logger) + + return {f'occluded/{k}': v for k, v in metrics.items()} + + def _evaluate_single_det_visible(self, result_path, logger=None, result_name='img_bbox'): + """Evaluate detection on visible objects, ignoring predictions that match occluded GT. + + This gives a fair vis/mAP comparison between a visible-only baseline and + a model trained to also predict occluded objects: detections of occluded + objects are not penalised as false positives. 
+ """ + from nuscenes import NuScenes + from .evaluation.det.occluded_det_eval import VisibleDetectionEval + + output_dir = osp.join(osp.dirname(result_path), 'visible_det') + nusc = NuScenes(version=self.version, dataroot=self.data_root, verbose=False) + eval_set_map = { + 'v1.0-mini': 'mini_val', + 'v1.0-trainval': 'val', + } + nusc_eval = VisibleDetectionEval( + nusc, + config=self.det3d_eval_configs, + result_path=result_path, + eval_set=eval_set_map[self.version], + output_dir=output_dir, + verbose=False, + ) + nusc_eval.main(render_curves=False) + + metrics = mmcv.load(osp.join(output_dir, 'metrics_summary.json')) + detail = {} + for name in self.CLASSES: + for k, v in metrics['label_aps'].get(name, {}).items(): + detail[f'vis/{name}_AP_dist_{k}'] = float('{:.4f}'.format(v)) + for k, v in metrics['label_tp_errors'].get(name, {}).items(): + detail[f'vis/{name}_{k}'] = float('{:.4f}'.format(v)) + for k, v in metrics['tp_errors'].items(): + detail[f'vis/{self.ErrNameMapping[k]}'] = float('{:.4f}'.format(v)) + detail['vis/NDS'] = metrics['nd_score'] + detail['vis/mAP'] = metrics['mean_ap'] + return detail + + def _evaluate_single_det_all(self, result_path, logger=None, result_name='img_bbox'): + """Evaluate detection on all objects (visible + occluded).""" + from nuscenes import NuScenes + from .evaluation.det.occluded_det_eval import AllDetectionEval + + output_dir = osp.join(osp.dirname(result_path), 'all_det') + nusc = NuScenes(version=self.version, dataroot=self.data_root, verbose=False) + eval_set_map = { + 'v1.0-mini': 'mini_val', + 'v1.0-trainval': 'val', + } + nusc_eval = AllDetectionEval( + nusc, + config=self.det3d_eval_configs, + result_path=result_path, + eval_set=eval_set_map[self.version], + output_dir=output_dir, + verbose=False, + ) + nusc_eval.main(render_curves=False) + + metrics = mmcv.load(osp.join(output_dir, 'metrics_summary.json')) + detail = {} + for name in self.CLASSES: + for k, v in metrics['label_aps'].get(name, {}).items(): + 
detail[f'all/{name}_AP_dist_{k}'] = float('{:.4f}'.format(v)) + for k, v in metrics['label_tp_errors'].get(name, {}).items(): + detail[f'all/{name}_{k}'] = float('{:.4f}'.format(v)) + for k, v in metrics['tp_errors'].items(): + detail[f'all/{self.ErrNameMapping[k]}'] = float('{:.4f}'.format(v)) + detail['all/NDS'] = metrics['nd_score'] + detail['all/mAP'] = metrics['mean_ap'] + return detail + + def _evaluate_single_motion_all(self, results, result_path, logger=None): + """Evaluate motion prediction on all objects (visible + occluded).""" + from nuscenes import NuScenes + from .evaluation.motion.motion_eval_uniad import AllMotionEval + + output_dir = osp.join(result_path, 'all_motion') + nusc = NuScenes( + version=self.version, dataroot=self.data_root, verbose=False) + eval_set_map = { + 'v1.0-mini': 'mini_val', + 'v1.0-trainval': 'val', + } + nusc_eval = AllMotionEval( + nusc, + config=copy.deepcopy(self.det3d_eval_configs), + result_path=results, + eval_set=eval_set_map[self.version], + output_dir=output_dir, + verbose=False, + seconds=6) + metrics = nusc_eval.main(render_curves=False) + + MOTION_METRICS = ['EPA', 'min_ade_err', 'min_fde_err', 'miss_rate_err'] + class_names = ['car', 'pedestrian'] + + table = prettytable.PrettyTable() + table.field_names = ["class names (all)"] + MOTION_METRICS + for class_name in class_names: + row_data = [class_name] + for m in MOTION_METRICS: + row_data.append('%.4f' % metrics[f'{class_name}_{m}']) + table.add_row(row_data) + print_log('\n[All Objects]\n' + str(table), logger=logger) + + return {f'all/{k}': v for k, v in metrics.items()} + + def _evaluate_visibility_accuracy(self, results): + """Evaluate the visibility head calibration on matched GT boxes. + + For each GT box (visible or occluded) we find the closest same-class + prediction within MATCH_DIST metres (BEV). We then compare the + prediction's sigmoid visibility score to the GT sensor-visibility flag + (num_lidar_pts > 0). 
+ + Returned metrics + ---------------- + visibility/accuracy : fraction correct at 0.5 threshold + visibility/auroc : area under the ROC curve + visibility/accuracy_visible : accuracy on visible-GT-matched pairs + visibility/accuracy_occluded : accuracy on occluded-GT-matched pairs + visibility/n_matched : total matched pairs across the val set + visibility/n_visible : matched pairs where GT is visible + visibility/n_occluded : matched pairs where GT is occluded + """ + MATCH_DIST = 4.0 # BEV centre-distance threshold in metres + + all_scores = [] # predicted visibility scores (sigmoid) + all_targets = [] # GT sensor-visibility (0.0 / 1.0) + + for i, result in enumerate(results): + det = result.get('img_bbox', result) + if 'visibility_scores' not in det: + return {} # head not enabled — skip entirely + + vis_scores = det['visibility_scores'].numpy() # (N,) + boxes = det['boxes_3d'].numpy() # (N, ≥2) + pred_labels = det['labels_3d'].numpy() # (N,) + + info = self.data_infos[i] + gt_boxes = info['gt_boxes'] # (M, 7) lidar frame + gt_names = info['gt_names'] # (M,) + + if 'num_lidar_pts' in info: + gt_vis = (info['num_lidar_pts'] > 0).astype(np.float32) + elif 'valid_flag' in info: + gt_vis = info['valid_flag'].astype(np.float32) + else: + continue + + if len(gt_boxes) == 0 or len(boxes) == 0: + continue + + pred_centers = boxes[:, :2] # (N, 2) BEV + gt_centers = gt_boxes[:, :2] # (M, 2) BEV + + for gi in range(len(gt_names)): + gt_cls = gt_names[gi] + if gt_cls not in self.CLASSES: + continue + gt_label = self.CLASSES.index(gt_cls) + + cls_idx = np.where(pred_labels == gt_label)[0] + if len(cls_idx) == 0: + continue + + dists = np.linalg.norm(pred_centers[cls_idx] - gt_centers[gi], axis=1) + nearest = dists.argmin() + if dists[nearest] <= MATCH_DIST: + all_scores.append(float(vis_scores[cls_idx[nearest]])) + all_targets.append(float(gt_vis[gi])) + + if len(all_targets) < 2: + return {} + + scores = np.array(all_scores, dtype=np.float64) + targets = 
np.array(all_targets, dtype=np.float64) + preds = (scores > 0.5).astype(np.float64) + + accuracy = float((preds == targets).mean()) + + # AUROC via trapezoidal rule (no external dependency) + pos = targets.sum() + neg = len(targets) - pos + if pos > 0 and neg > 0: + order = np.argsort(scores)[::-1] + t_sorted = targets[order] + tprs = np.concatenate([[0.0], np.cumsum(t_sorted == 1) / pos, [1.0]]) + fprs = np.concatenate([[0.0], np.cumsum(t_sorted == 0) / neg, [1.0]]) + auroc = float(np.trapz(tprs, fprs)) + else: + auroc = float('nan') + + vis_mask = targets == 1.0 + occ_mask = targets == 0.0 + acc_vis = float((preds[vis_mask] == targets[vis_mask]).mean()) if vis_mask.any() else float('nan') + acc_occ = float((preds[occ_mask] == targets[occ_mask]).mean()) if occ_mask.any() else float('nan') + + return { + 'visibility/accuracy': accuracy, + 'visibility/auroc': auroc, + 'visibility/accuracy_visible': acc_vis, + 'visibility/accuracy_occluded': acc_occ, + 'visibility/n_matched': int(len(targets)), + 'visibility/n_visible': int(vis_mask.sum()), + 'visibility/n_occluded': int(occ_mask.sum()), + } + + def evaluate( + self, + results, + eval_mode, + metric=None, + logger=None, + jsonfile_prefix=None, + result_names=["img_bbox"], + show=False, + out_dir=None, + pipeline=None, + ): + res_path = "results.pkl" if "trainval" in self.version else "results_mini.pkl" + res_path = osp.join(self.work_dir, res_path) + print('All Results write to', res_path) + mmcv.dump(results, res_path) + + results_dict = dict() + detection_result_files = None + if eval_mode['with_det']: + self.tracking = eval_mode["with_tracking"] + self.tracking_threshold = eval_mode["tracking_threshold"] + for metric in ["detection", "tracking"]: + tracking = metric == "tracking" + if tracking and not self.tracking: + continue + result_files, tmp_dir = self.format_results( + results, jsonfile_prefix=self.work_dir, tracking=tracking + ) + if not tracking: + detection_result_files = result_files + + if 
    def evaluate(
        self,
        results,
        eval_mode,
        metric=None,
        logger=None,
        jsonfile_prefix=None,
        result_names=["img_bbox"],  # NOTE(review): mutable default — read-only here, but prefer None-sentinel
        show=False,
        out_dir=None,
        pipeline=None,
    ):
        """Run every evaluation selected by ``eval_mode`` and collect metrics.

        Args:
            results (list[dict]): Raw per-sample network outputs.
            eval_mode (dict): Switches and thresholds: 'with_det',
                'with_tracking', 'tracking_threshold', 'with_map',
                'with_motion', 'motion_threshhold' (key spelled with a double
                'h' by the config producers — do not "fix" it here),
                'with_planning' and optionally 'with_occlusion'.
            metric: Placeholder; reassigned internally by the det/track loop.
            logger: Logger for `print_log`.
            jsonfile_prefix: Unused; results are written under self.work_dir.
            result_names (list[str]): Keys to evaluate in dict-typed result
                files.
            show (bool): Also render visualisations via ``self.show``.
            out_dir (str | None): Directory for visualisations.
            pipeline: Data pipeline used by ``self.show``.

        Returns:
            dict: Flat mapping of metric name -> value across all enabled
            evaluators.
        """
        # dump raw results so evaluations can be re-run offline
        res_path = "results.pkl" if "trainval" in self.version else "results_mini.pkl"
        res_path = osp.join(self.work_dir, res_path)
        print('All Results write to', res_path)
        mmcv.dump(results, res_path)

        results_dict = dict()
        detection_result_files = None
        if eval_mode['with_det']:
            self.tracking = eval_mode["with_tracking"]
            self.tracking_threshold = eval_mode["tracking_threshold"]
            for metric in ["detection", "tracking"]:
                tracking = metric == "tracking"
                if tracking and not self.tracking:
                    continue
                result_files, tmp_dir = self.format_results(
                    results, jsonfile_prefix=self.work_dir, tracking=tracking
                )
                # keep the detection files for the occlusion-split evals below
                if not tracking:
                    detection_result_files = result_files

                if isinstance(result_files, dict):
                    for name in result_names:
                        ret_dict = self._evaluate_single(
                            result_files[name], tracking=tracking
                        )
                        results_dict.update(ret_dict)
                elif isinstance(result_files, str):
                    ret_dict = self._evaluate_single(
                        result_files, tracking=tracking
                    )
                    results_dict.update(ret_dict)

                if tmp_dir is not None:
                    tmp_dir.cleanup()

            # visibility-head calibration (no-op {} when head disabled)
            vis_metrics = self._evaluate_visibility_accuracy(results)
            results_dict.update(vis_metrics)

        if eval_mode['with_map']:
            from .evaluation.map.vector_eval import VectorEvaluate
            self.map_evaluator = VectorEvaluate(self.eval_config)
            result_path = self.format_map_results(results, prefix=self.work_dir)
            map_results_dict = self.map_evaluator.evaluate(result_path, logger=logger)
            results_dict.update(map_results_dict)

        motion_result_files = None
        if eval_mode['with_motion']:
            thresh = eval_mode["motion_threshhold"]
            motion_result_files = self.format_motion_results(results, jsonfile_prefix=self.work_dir, thresh=thresh)
            motion_results_dict = self._evaluate_single_motion(motion_result_files, self.work_dir, logger=logger)
            results_dict.update(motion_results_dict)

        if eval_mode.get('with_occlusion', False):
            thresh = eval_mode["motion_threshhold"]
            # motion files may not exist yet when with_motion was False
            if motion_result_files is None:
                motion_result_files = self.format_motion_results(results, jsonfile_prefix=self.work_dir, thresh=thresh)
            occluded_results_dict = self._evaluate_single_motion_occluded(
                motion_result_files, self.work_dir, logger=logger)
            results_dict.update(occluded_results_dict)

            # occlusion-split detection metrics (visible / occluded / all)
            if detection_result_files is not None:
                if isinstance(detection_result_files, dict):
                    for name in result_names:
                        vis_det_dict = self._evaluate_single_det_visible(
                            detection_result_files[name], logger=logger, result_name=name)
                        results_dict.update(vis_det_dict)
                        occ_det_dict = self._evaluate_single_det_occluded(
                            detection_result_files[name], logger=logger, result_name=name)
                        results_dict.update(occ_det_dict)
                        all_det_dict = self._evaluate_single_det_all(
                            detection_result_files[name], logger=logger, result_name=name)
                        results_dict.update(all_det_dict)
                elif isinstance(detection_result_files, str):
                    vis_det_dict = self._evaluate_single_det_visible(
                        detection_result_files, logger=logger)
                    results_dict.update(vis_det_dict)
                    occ_det_dict = self._evaluate_single_det_occluded(
                        detection_result_files, logger=logger)
                    results_dict.update(occ_det_dict)
                    all_det_dict = self._evaluate_single_det_all(
                        detection_result_files, logger=logger)
                    results_dict.update(all_det_dict)

            all_results_dict = self._evaluate_single_motion_all(
                motion_result_files, self.work_dir, logger=logger)
            results_dict.update(all_results_dict)

        if eval_mode['with_planning']:
            from .evaluation.planning.planning_eval import planning_eval
            planning_results_dict = planning_eval(
                results, self.eval_config, logger=logger,
                with_occlusion=eval_mode.get('with_occlusion', False))
            results_dict.update(planning_results_dict)

        if show or out_dir:
            self.show(results, save_dir=out_dir, show=show, pipeline=pipeline)

        # print main metrics for recording
        metric_str = '\n'
        if "img_bbox_NuScenes/NDS" in results_dict:
            metric_str += f'mAP: {results_dict.get("img_bbox_NuScenes/mAP"):.4f}\n'
            metric_str += f'mATE: {results_dict.get("img_bbox_NuScenes/mATE"):.4f}\n'
            metric_str += f'mASE: {results_dict.get("img_bbox_NuScenes/mASE"):.4f}\n'
            metric_str += f'mAOE: {results_dict.get("img_bbox_NuScenes/mAOE"):.4f}\n'
            metric_str += f'mAVE: {results_dict.get("img_bbox_NuScenes/mAVE"):.4f}\n'
            metric_str += f'mAAE: {results_dict.get("img_bbox_NuScenes/mAAE"):.4f}\n'
            metric_str += f'NDS: {results_dict.get("img_bbox_NuScenes/NDS"):.4f}\n\n'

        if "img_bbox_NuScenes/amota" in results_dict:
            metric_str += f'AMOTA: {results_dict["img_bbox_NuScenes/amota"]:.4f}\n'
            metric_str += f'AMOTP: {results_dict["img_bbox_NuScenes/amotp"]:.4f}\n'
            metric_str += f'RECALL: {results_dict["img_bbox_NuScenes/recall"]:.4f}\n'
            metric_str += f'MOTAR: {results_dict["img_bbox_NuScenes/motar"]:.4f}\n'
            metric_str += f'MOTA: {results_dict["img_bbox_NuScenes/mota"]:.4f}\n'
            metric_str += f'MOTP: {results_dict["img_bbox_NuScenes/motp"]:.4f}\n'
            metric_str += f'IDS: {results_dict["img_bbox_NuScenes/ids"]}\n\n'

        if "mAP_normal" in results_dict:
            metric_str += f'ped_crossing= {results_dict["ped_crossing"]:.4f}\n'
            metric_str += f'divider= {results_dict["divider"]:.4f}\n'
            metric_str += f'boundary= {results_dict["boundary"]:.4f}\n'
            metric_str += f'mAP_normal= {results_dict["mAP_normal"]:.4f}\n\n'

        if "car_EPA" in results_dict:
            metric_str += f'Car / Ped\n'
            metric_str += f'epa= {results_dict["car_EPA"]:.4f} / {results_dict["pedestrian_EPA"]:.4f}\n'
            metric_str += f'ade= {results_dict["car_min_ade_err"]:.4f} / {results_dict["pedestrian_min_ade_err"]:.4f}\n'
            metric_str += f'fde= {results_dict["car_min_fde_err"]:.4f} / {results_dict["pedestrian_min_fde_err"]:.4f}\n'
            metric_str += f'mr= {results_dict["car_miss_rate_err"]:.4f} / {results_dict["pedestrian_miss_rate_err"]:.4f}\n\n'

        if "occluded/NDS" in results_dict:
            metric_str += f'[Occluded Det]\n'
            metric_str += f'mAP: {results_dict["occluded/mAP"]:.4f}\n'
            metric_str += f'mATE: {results_dict["occluded/mATE"]:.4f}\n'
            metric_str += f'mASE: {results_dict["occluded/mASE"]:.4f}\n'
            metric_str += f'mAOE: {results_dict["occluded/mAOE"]:.4f}\n'
            metric_str += f'mAVE: {results_dict["occluded/mAVE"]:.4f}\n'
            metric_str += f'mAAE: {results_dict["occluded/mAAE"]:.4f}\n'
            metric_str += f'NDS: {results_dict["occluded/NDS"]:.4f}\n\n'

        if "occluded/car_EPA" in results_dict:
            metric_str += f'[Occluded Motion] Car / Ped\n'
            metric_str += f'epa= {results_dict["occluded/car_EPA"]:.4f} / {results_dict["occluded/pedestrian_EPA"]:.4f}\n'
            metric_str += f'ade= {results_dict["occluded/car_min_ade_err"]:.4f} / {results_dict["occluded/pedestrian_min_ade_err"]:.4f}\n'
            metric_str += f'fde= {results_dict["occluded/car_min_fde_err"]:.4f} / {results_dict["occluded/pedestrian_min_fde_err"]:.4f}\n'
            metric_str += f'mr= {results_dict["occluded/car_miss_rate_err"]:.4f} / {results_dict["occluded/pedestrian_miss_rate_err"]:.4f}\n\n'

        if "all/NDS" in results_dict:
            metric_str += f'[All Det]\n'
            metric_str += f'mAP: {results_dict["all/mAP"]:.4f}\n'
            metric_str += f'mATE: {results_dict["all/mATE"]:.4f}\n'
            metric_str += f'mASE: {results_dict["all/mASE"]:.4f}\n'
            metric_str += f'mAOE: {results_dict["all/mAOE"]:.4f}\n'
            metric_str += f'mAVE: {results_dict["all/mAVE"]:.4f}\n'
            metric_str += f'mAAE: {results_dict["all/mAAE"]:.4f}\n'
            metric_str += f'NDS: {results_dict["all/NDS"]:.4f}\n\n'

        if "all/car_EPA" in results_dict:
            metric_str += f'[All Motion] Car / Ped\n'
            metric_str += f'epa= {results_dict["all/car_EPA"]:.4f} / {results_dict["all/pedestrian_EPA"]:.4f}\n'
            metric_str += f'ade= {results_dict["all/car_min_ade_err"]:.4f} / {results_dict["all/pedestrian_min_ade_err"]:.4f}\n'
            metric_str += f'fde= {results_dict["all/car_min_fde_err"]:.4f} / {results_dict["all/pedestrian_min_fde_err"]:.4f}\n'
            metric_str += f'mr= {results_dict["all/car_miss_rate_err"]:.4f} / {results_dict["all/pedestrian_miss_rate_err"]:.4f}\n\n'

        if 'visibility/accuracy' in results_dict:
            rd = results_dict
            metric_str += f'[Visibility Head]\n'
            metric_str += (
                f'accuracy= {rd["visibility/accuracy"]:.4f} '
                f'(visible={rd["visibility/accuracy_visible"]:.4f}, '
                f'occluded={rd["visibility/accuracy_occluded"]:.4f})\n'
            )
            metric_str += f'auroc= {rd["visibility/auroc"]:.4f}\n'
            metric_str += (
                f'matched: {rd["visibility/n_matched"]} '
                f'(visible={rd["visibility/n_visible"]}, '
                f'occluded={rd["visibility/n_occluded"]})\n\n'
            )

        if "L2" in results_dict:
            metric_str += f'obj_box_col: {(results_dict["obj_box_col"]*100):.3f}%\n'
            metric_str += f'L2: {results_dict["L2"]:.4f}\n'
            if "occluded/obj_box_col" in results_dict:
                metric_str += f'obj_box_col_occluded: {(results_dict["occluded/obj_box_col"]*100):.3f}%\n'
            metric_str += '\n'

        print_log(metric_str, logger=logger)
        return results_dict
    def show(self, results, save_dir=None, show=False, pipeline=None):
        """Render predictions onto the camera views plus a BEV panel.

        Writes one ``<i>.jpg`` per sample and a ``video.avi`` under
        ``<save_dir>/visual``. Assumes six camera views ordered
        front / front-right / front-left / rear / rear-left / rear-right
        (matches the hard-coded label list below — TODO confirm against the
        data pipeline).

        Args:
            results (list[dict]): Per-sample outputs (optionally wrapped in
                'img_bbox').
            save_dir (str | None): Output root; defaults to './'.
            show (bool): Unused here; kept for mmdet3d API compatibility.
            pipeline: Pipeline config used to load raw images.
        """
        save_dir = "./" if save_dir is None else save_dir
        save_dir = os.path.join(save_dir, "visual")
        print_log(os.path.abspath(save_dir))
        pipeline = Compose(pipeline)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        fourcc = cv2.VideoWriter_fourcc(*"MJPG")
        # NOTE(review): stays None (and .release() below raises) when
        # `results` is empty — confirm callers never pass an empty list.
        videoWriter = None

        for i, result in enumerate(results):
            if "img_bbox" in result.keys():
                result = result["img_bbox"]
            data_info = pipeline(self.get_data_info(i))
            imgs = []

            raw_imgs = data_info["img"]
            lidar2img = data_info["img_metas"].data["lidar2img"]
            # keep only confident boxes for display
            pred_bboxes_3d = result["boxes_3d"][
                result["scores_3d"] > self.vis_score_threshold
            ]
            # colour per track id when tracking, else per class label
            if "instance_ids" in result and self.tracking:
                color = []
                for id in result["instance_ids"].cpu().numpy().tolist():
                    color.append(
                        self.ID_COLOR_MAP[int(id % len(self.ID_COLOR_MAP))]
                    )
            elif "labels_3d" in result:
                color = []
                for id in result["labels_3d"].cpu().numpy().tolist():
                    color.append(self.ID_COLOR_MAP[id])
            else:
                color = (255, 0, 0)

            # ===== draw boxes_3d to images =====
            for j, img_origin in enumerate(raw_imgs):
                img = img_origin.copy()
                if len(pred_bboxes_3d) != 0:
                    img = draw_lidar_bbox3d_on_img(
                        pred_bboxes_3d,
                        img,
                        lidar2img[j],
                        img_metas=None,
                        color=color,
                        thickness=3,
                    )
                imgs.append(img)

            # ===== draw boxes_3d to BEV =====
            bev = draw_lidar_bbox3d_on_bev(
                pred_bboxes_3d,
                bev_size=img.shape[0] * 2,
                color=color,
            )

            # ===== put text and concat =====
            for j, name in enumerate(
                [
                    "front",
                    "front right",
                    "front left",
                    "rear",
                    "rear left",
                    "rear right",
                ]
            ):
                # white label strip in the top-left corner of each view
                imgs[j] = cv2.rectangle(
                    imgs[j],
                    (0, 0),
                    (440, 80),
                    color=(255, 255, 255),
                    thickness=-1,
                )
                w, h = cv2.getTextSize(name, cv2.FONT_HERSHEY_SIMPLEX, 2, 2)[0]
                text_x = int(220 - w / 2)
                text_y = int(40 + h / 2)

                imgs[j] = cv2.putText(
                    imgs[j],
                    name,
                    (text_x, text_y),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    2,
                    (0, 0, 0),
                    2,
                    cv2.LINE_AA,
                )
            # 2x3 camera mosaic with BEV appended on the right
            image = np.concatenate(
                [
                    np.concatenate([imgs[2], imgs[0], imgs[1]], axis=1),
                    np.concatenate([imgs[5], imgs[3], imgs[4]], axis=1),
                ],
                axis=0,
            )
            image = np.concatenate([image, bev], axis=1)

            # ===== save video =====
            if videoWriter is None:
                videoWriter = cv2.VideoWriter(
                    os.path.join(save_dir, "video.avi"),
                    fourcc,
                    7,
                    image.shape[:2][::-1],
                )
            cv2.imwrite(os.path.join(save_dir, f"{i}.jpg"), image)
            videoWriter.write(image)
        videoWriter.release()
label=labels[i], + score=scores[i], + velocity=velocity, + ) + if "instance_ids" in detection: + box.token = ids[i] + box_list.append(box) + return box_list + + +def lidar_nusc_box_to_global( + info, + boxes, + classes, + eval_configs, + eval_version="detection_cvpr_2019", + filter_with_cls_range=True, +): + box_list = [] + for i, box in enumerate(boxes): + # Move box to ego vehicle coord system + box.rotate(pyquaternion.Quaternion(info["lidar2ego_rotation"])) + box.translate(np.array(info["lidar2ego_translation"])) + # filter det in ego. + if filter_with_cls_range: + cls_range_map = eval_configs.class_range + radius = np.linalg.norm(box.center[:2], 2) + det_range = cls_range_map[classes[box.label]] + if radius > det_range: + continue + # Move box to global coord system + box.rotate(pyquaternion.Quaternion(info["ego2global_rotation"])) + box.translate(np.array(info["ego2global_translation"])) + box_list.append(box) + return box_list + + +def get_T_global(info): + lidar2ego = np.eye(4) + lidar2ego[:3, :3] = pyquaternion.Quaternion( + info["lidar2ego_rotation"] + ).rotation_matrix + lidar2ego[:3, 3] = np.array(info["lidar2ego_translation"]) + ego2global = np.eye(4) + ego2global[:3, :3] = pyquaternion.Quaternion( + info["ego2global_rotation"] + ).rotation_matrix + ego2global[:3, 3] = np.array(info["ego2global_translation"]) + return ego2global @ lidar2ego \ No newline at end of file diff --git a/projects/mmdet3d_plugin/datasets/pipelines/__init__.py b/projects/mmdet3d_plugin/datasets/pipelines/__init__.py new file mode 100644 index 0000000..ce8bec5 --- /dev/null +++ b/projects/mmdet3d_plugin/datasets/pipelines/__init__.py @@ -0,0 +1,28 @@ +from .transform import ( + InstanceNameFilter, + CircleObjectRangeFilter, + NormalizeMultiviewImage, + NuScenesSparse4DAdaptor, + MultiScaleDepthMapGenerator, +) +from .augment import ( + ResizeCropFlipImage, + BBoxRotation, + PhotoMetricDistortionMultiViewImage, +) +from .loading import LoadMultiViewImageFromFiles, 
@PIPELINES.register_module()
class ResizeCropFlipImage(object):
    """Resize / crop / flip / rotate each camera image and update projections.

    Per-sample augmentation parameters come from ``results["aug_config"]``;
    the equivalent pixel-space transform is pre-multiplied onto each view's
    ``lidar2img`` matrix so that 3D-to-image projection stays consistent.
    """

    def __call__(self, results):
        aug_config = results.get("aug_config")
        if aug_config is None:
            return results

        transformed = []
        for idx, raw in enumerate(results["img"]):
            img, mat = self._img_transform(np.uint8(raw), aug_config)
            transformed.append(np.array(img).astype(np.float32))
            results["lidar2img"][idx] = mat @ results["lidar2img"][idx]
            if "cam_intrinsic" in results:
                results["cam_intrinsic"][idx][:3, :3] *= aug_config["resize"]

        results["img"] = transformed
        results["img_shape"] = [im.shape[:2] for im in transformed]
        return results

    def _img_transform(self, img, aug_configs):
        """Apply resize -> crop -> flip -> rotate; return (image, 4x4 matrix)."""
        H, W = img.shape[:2]
        resize = aug_configs.get("resize", 1)
        resize_dims = (int(W * resize), int(H * resize))
        crop = aug_configs.get("crop", [0, 0, *resize_dims])
        flip = aug_configs.get("flip", False)
        rotate = aug_configs.get("rotate", 0)

        # Non-uint8 inputs are mapped linearly onto [0, 255] for PIL and
        # mapped back after the geometric ops.
        origin_dtype = img.dtype
        if origin_dtype != np.uint8:
            low = img.min()
            high = img.max()
            scale = 255 / (high - low)
            img = np.uint8((img - low) * scale)
        img = Image.fromarray(img)
        img = img.resize(resize_dims).crop(crop)
        if flip:
            img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
        img = img.rotate(rotate)
        img = np.array(img).astype(np.float32)
        if origin_dtype != np.uint8:
            img = img.astype(np.float32)
            img = img / scale + low

        # Accumulate the matching 3x3 pixel transform: scale, crop shift,
        # optional horizontal flip, then rotation about the crop centre.
        transform_matrix = np.eye(3)
        transform_matrix[:2, :2] *= resize
        transform_matrix[:2, 2] -= np.array(crop[:2])
        if flip:
            flip_matrix = np.array(
                [[-1, 0, crop[2] - crop[0]], [0, 1, 0], [0, 0, 1]]
            )
            transform_matrix = flip_matrix @ transform_matrix
        theta = rotate / 180 * np.pi
        rot_matrix = np.array(
            [
                [np.cos(theta), np.sin(theta), 0],
                [-np.sin(theta), np.cos(theta), 0],
                [0, 0, 1],
            ]
        )
        centre = np.array([crop[2] - crop[0], crop[3] - crop[1]]) / 2
        rot_matrix[:2, 2] = -rot_matrix[:2, :2] @ centre + centre
        transform_matrix = rot_matrix @ transform_matrix

        extend_matrix = np.eye(4)
        extend_matrix[:3, :3] = transform_matrix
        return img, extend_matrix


@PIPELINES.register_module()
class BBoxRotation(object):
    """Rotate the scene around the lidar z-axis by ``aug_config['rotate_3d']``.

    Camera projections are post-multiplied with the inverse rotation so that
    pixels keep pointing at the rotated 3D geometry; GT boxes are rotated
    in place.
    """

    def __call__(self, results):
        angle = results["aug_config"]["rotate_3d"]
        cos_a = np.cos(angle)
        sin_a = np.sin(angle)

        # Homogeneous 4x4 rotation about +z.
        rot_mat = np.array(
            [
                [cos_a, -sin_a, 0, 0],
                [sin_a, cos_a, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1],
            ]
        )
        rot_mat_inv = np.linalg.inv(rot_mat)

        for view in range(len(results["lidar2img"])):
            results["lidar2img"][view] = results["lidar2img"][view] @ rot_mat_inv
        if "lidar2global" in results:
            results["lidar2global"] = results["lidar2global"] @ rot_mat_inv
        if "gt_bboxes_3d" in results:
            results["gt_bboxes_3d"] = self.box_rotate(
                results["gt_bboxes_3d"], angle
            )
        return results

    @staticmethod
    def box_rotate(bbox_3d, angle):
        """Rotate boxes (x, y, z, dims..., yaw[, vx, vy]) about +z by ``angle``."""
        cos_a = np.cos(angle)
        sin_a = np.sin(angle)
        rot_mat_T = np.array(
            [[cos_a, sin_a, 0], [-sin_a, cos_a, 0], [0, 0, 1]]
        )
        bbox_3d[:, :3] = bbox_3d[:, :3] @ rot_mat_T
        bbox_3d[:, 6] += angle
        # Rotate any trailing velocity components with the matching sub-matrix.
        if bbox_3d.shape[-1] > 7:
            vel_dims = bbox_3d[:, 7:].shape[-1]
            bbox_3d[:, 7:] = bbox_3d[:, 7:] @ rot_mat_T[:vel_dims, :vel_dims]
        return bbox_3d
@ rot_mat_T + bbox_3d[:, 6] += angle + if bbox_3d.shape[-1] > 7: + vel_dims = bbox_3d[:, 7:].shape[-1] + bbox_3d[:, 7:] = bbox_3d[:, 7:] @ rot_mat_T[:vel_dims, :vel_dims] + return bbox_3d + + +@PIPELINES.register_module() +class PhotoMetricDistortionMultiViewImage: + """Apply photometric distortion to image sequentially, every transformation + is applied with a probability of 0.5. The position of random contrast is in + second or second to last. + 1. random brightness + 2. random contrast (mode 0) + 3. convert color from BGR to HSV + 4. random saturation + 5. random hue + 6. convert color from HSV to BGR + 7. random contrast (mode 1) + 8. randomly swap channels + Args: + brightness_delta (int): delta of brightness. + contrast_range (tuple): range of contrast. + saturation_range (tuple): range of saturation. + hue_delta (int): delta of hue. + """ + + def __init__( + self, + brightness_delta=32, + contrast_range=(0.5, 1.5), + saturation_range=(0.5, 1.5), + hue_delta=18, + ): + self.brightness_delta = brightness_delta + self.contrast_lower, self.contrast_upper = contrast_range + self.saturation_lower, self.saturation_upper = saturation_range + self.hue_delta = hue_delta + + def __call__(self, results): + """Call function to perform photometric distortion on images. + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Result dict with images distorted. 
+ """ + imgs = results["img"] + new_imgs = [] + for img in imgs: + assert img.dtype == np.float32, ( + "PhotoMetricDistortion needs the input image of dtype np.float32," + ' please set "to_float32=True" in "LoadImageFromFile" pipeline' + ) + # random brightness + if random.randint(2): + delta = random.uniform( + -self.brightness_delta, self.brightness_delta + ) + img += delta + + # mode == 0 --> do random contrast first + # mode == 1 --> do random contrast last + mode = random.randint(2) + if mode == 1: + if random.randint(2): + alpha = random.uniform( + self.contrast_lower, self.contrast_upper + ) + img *= alpha + + # convert color from BGR to HSV + img = mmcv.bgr2hsv(img) + + # random saturation + if random.randint(2): + img[..., 1] *= random.uniform( + self.saturation_lower, self.saturation_upper + ) + + # random hue + if random.randint(2): + img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta) + img[..., 0][img[..., 0] > 360] -= 360 + img[..., 0][img[..., 0] < 0] += 360 + + # convert color from HSV to BGR + img = mmcv.hsv2bgr(img) + + # random contrast + if mode == 0: + if random.randint(2): + alpha = random.uniform( + self.contrast_lower, self.contrast_upper + ) + img *= alpha + + # randomly swap channels + if random.randint(2): + img = img[..., random.permutation(3)] + new_imgs.append(img) + results["img"] = new_imgs + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(\nbrightness_delta={self.brightness_delta},\n" + repr_str += "contrast_range=" + repr_str += f"{(self.contrast_lower, self.contrast_upper)},\n" + repr_str += "saturation_range=" + repr_str += f"{(self.saturation_lower, self.saturation_upper)},\n" + repr_str += f"hue_delta={self.hue_delta})" + return repr_str diff --git a/projects/mmdet3d_plugin/datasets/pipelines/loading.py b/projects/mmdet3d_plugin/datasets/pipelines/loading.py new file mode 100644 index 0000000..cb743ec --- /dev/null +++ 
# ---------------------------------------------------------------------------
# projects/mmdet3d_plugin/datasets/pipelines/loading.py
# ---------------------------------------------------------------------------
import numpy as np
import mmcv
from mmdet.datasets.builder import PIPELINES


@PIPELINES.register_module()
class LoadMultiViewImageFromFiles(object):
    """Load multi channel images from a list of separate channel files.

    Expects ``results['img_filename']`` to be a list of filenames.

    Args:
        to_float32 (bool, optional): Whether to convert the img to float32.
            Defaults to False.
        color_type (str, optional): Color type of the file.
            Defaults to 'unchanged'.
    """

    def __init__(self, to_float32=False, color_type="unchanged"):
        self.to_float32 = to_float32
        self.color_type = color_type

    def __call__(self, results):
        """Load every camera view listed in ``results['img_filename']``.

        Args:
            results (dict): Result dict containing multi-view image filenames.

        Returns:
            dict: Result dict with added keys:
                - filename (list[str]): Multi-view image filenames.
                - img (list[np.ndarray]): One HxWxC array per view.
                - img_shape / ori_shape / pad_shape (tuple[int]): Shape of
                  the stacked array (H, W, C, num_views).
                - scale_factor (float): Initial scale factor (1.0).
                - img_norm_cfg (dict): Neutral normalization config.
        """
        filename = results["img_filename"]
        # Stack views on a trailing axis: (h, w, c, num_views).
        img = np.stack(
            [mmcv.imread(name, self.color_type) for name in filename], axis=-1
        )
        if self.to_float32:
            img = img.astype(np.float32)
        results["filename"] = filename
        # Unravel to a list, see `DefaultFormatBundle` in formatting.py,
        # which transposes each image separately then stacks into an array.
        results["img"] = [img[..., i] for i in range(img.shape[-1])]
        results["img_shape"] = img.shape
        results["ori_shape"] = img.shape
        # Initial values for the default meta keys; later stages (padding,
        # resizing) overwrite them.
        results["pad_shape"] = img.shape
        results["scale_factor"] = 1.0
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results["img_norm_cfg"] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False,
        )
        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f"(to_float32={self.to_float32}, "
        repr_str += f"color_type='{self.color_type}')"
        return repr_str


@PIPELINES.register_module()
class LoadPointsFromFile(object):
    """Load points from file.

    Args:
        coord_type (str): The type of coordinates of points cloud.
            Available options are:
            - 'LIDAR': Points in LiDAR coordinates.
            - 'DEPTH': Points in depth coordinates, usually indoor datasets.
            - 'CAMERA': Points in camera coordinates.
        load_dim (int, optional): The dimension of the loaded points.
            Defaults to 6.
        use_dim (list[int] | int, optional): Which dimensions of the points
            to use. Defaults to [0, 1, 2]. For KITTI, set use_dim=4 or
            use_dim=[0, 1, 2, 3] to use the intensity dimension.
        shift_height (bool, optional): Whether to use shifted height.
            Defaults to False.
        use_color (bool, optional): Whether to use color features.
            Defaults to False.
        file_client_args (dict, optional): Config dict of file clients; see
            mmcv.fileio.file_client. Defaults to dict(backend='disk').
    """

    def __init__(
        self,
        coord_type,
        load_dim=6,
        use_dim=[0, 1, 2],
        shift_height=False,
        use_color=False,
        file_client_args=dict(backend="disk"),
    ):
        self.shift_height = shift_height
        self.use_color = use_color
        if isinstance(use_dim, int):
            use_dim = list(range(use_dim))
        assert (
            max(use_dim) < load_dim
        ), f"Expect all used dimensions < {load_dim}, got {use_dim}"
        assert coord_type in ["CAMERA", "LIDAR", "DEPTH"]

        self.coord_type = coord_type
        self.load_dim = load_dim
        self.use_dim = use_dim
        self.file_client_args = file_client_args.copy()
        # Lazily constructed on first use so the pipeline object stays
        # picklable across dataloader workers.
        self.file_client = None

    def _load_points(self, pts_filename):
        """Load raw point data from ``pts_filename``.

        Falls back to local-disk reading when the configured file client
        cannot connect (e.g. remote backend unavailable).

        Args:
            pts_filename (str): Filename of point clouds data.

        Returns:
            np.ndarray: Flat float32 array of point data.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)
        try:
            pts_bytes = self.file_client.get(pts_filename)
            points = np.frombuffer(pts_bytes, dtype=np.float32)
        except ConnectionError:
            mmcv.check_file_exist(pts_filename)
            if pts_filename.endswith(".npy"):
                points = np.load(pts_filename)
            else:
                points = np.fromfile(pts_filename, dtype=np.float32)

        return points

    def __call__(self, results):
        """Load points and store them under ``results['points']``.

        Args:
            results (dict): Result dict containing ``pts_filename``.

        Returns:
            dict: Result dict with ``points`` (np.ndarray of shape
            (N, len(use_dim)) plus optional height column) added.
        """
        pts_filename = results["pts_filename"]
        points = self._load_points(pts_filename)
        points = points.reshape(-1, self.load_dim)
        points = points[:, self.use_dim]
        attribute_dims = None

        if self.shift_height:
            # NOTE(review): 0.99 is the *percentile value* (i.e. the
            # 0.99th percentile, near the minimum) — this mirrors the
            # upstream mmdet3d implementation; confirm it is intentional.
            floor_height = np.percentile(points[:, 2], 0.99)
            height = points[:, 2] - floor_height
            points = np.concatenate(
                [points[:, :3], np.expand_dims(height, 1), points[:, 3:]], 1
            )
            attribute_dims = dict(height=3)

        if self.use_color:
            assert len(self.use_dim) >= 6
            if attribute_dims is None:
                attribute_dims = dict()
            # Color occupies the last three columns.
            attribute_dims.update(
                dict(
                    color=[
                        points.shape[1] - 3,
                        points.shape[1] - 2,
                        points.shape[1] - 1,
                    ]
                )
            )
        # NOTE(review): attribute_dims is computed but not stored in
        # `results`; kept as-is to preserve the existing output contract.

        results["points"] = points
        return results


# ---------------------------------------------------------------------------
# projects/mmdet3d_plugin/datasets/pipelines/transform.py
# ---------------------------------------------------------------------------
import numpy as np
import mmcv
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import to_tensor


@PIPELINES.register_module()
class MultiScaleDepthMapGenerator(object):
    """Project lidar points into every camera view to build sparse ground
    truth depth maps at one or more downsampled resolutions.

    Args:
        downsample (int | list[int]): Downsample factor(s) of the output
            depth maps relative to the image resolution.
        max_depth (float): Depths are clipped to [0.1, max_depth].
    """

    def __init__(self, downsample=1, max_depth=60):
        if not isinstance(downsample, (list, tuple)):
            downsample = [downsample]
        self.downsample = downsample
        self.max_depth = max_depth

    def __call__(self, input_dict):
        points = input_dict["points"][..., :3, None]
        gt_depth = []
        for i, lidar2img in enumerate(input_dict["lidar2img"]):
            H, W = input_dict["img_shape"][i][:2]

            # Project into the image plane: R @ p + t, then perspective
            # division by depth.
            pts_2d = (
                np.squeeze(lidar2img[:3, :3] @ points, axis=-1)
                + lidar2img[:3, 3]
            )
            pts_2d[:, :2] /= pts_2d[:, 2:3]
            U = np.round(pts_2d[:, 0]).astype(np.int32)
            V = np.round(pts_2d[:, 1]).astype(np.int32)
            depths = pts_2d[:, 2]
            # Keep only points inside the image and in front of the camera.
            mask = np.logical_and.reduce(
                [
                    V >= 0,
                    V < H,
                    U >= 0,
                    U < W,
                    depths >= 0.1,
                    # depths <= self.max_depth,
                ]
            )
            V, U, depths = V[mask], U[mask], depths[mask]
            # Sort far-to-near so nearer points overwrite farther ones when
            # several points land on the same pixel.
            sort_idx = np.argsort(depths)[::-1]
            V, U, depths = V[sort_idx], U[sort_idx], depths[sort_idx]
            depths = np.clip(depths, 0.1, self.max_depth)
            for j, downsample in enumerate(self.downsample):
                if len(gt_depth) < j + 1:
                    gt_depth.append([])
                h, w = (int(H / downsample), int(W / downsample))
                u = np.floor(U / downsample).astype(np.int32)
                v = np.floor(V / downsample).astype(np.int32)
                # -1 marks pixels with no depth observation.
                depth_map = np.ones([h, w], dtype=np.float32) * -1
                depth_map[v, u] = depths
                gt_depth[j].append(depth_map)

        input_dict["gt_depth"] = [np.stack(x) for x in gt_depth]
        return input_dict


@PIPELINES.register_module()
class NuScenesSparse4DAdaptor(object):
    """Repack loaded annotations into the tensor / DataContainer layout
    expected by the Sparse4D model."""

    def __init__(self):
        # Fix: this was declared as `__init` (typo), leaving it as an
        # ordinary, never-called method. The class is stateless, so the
        # corrected constructor is a no-op.
        pass

    def __call__(self, input_dict):
        input_dict["projection_mat"] = np.float32(
            np.stack(input_dict["lidar2img"])
        )
        # img_shape rows are (h, w, ...); reverse to (w, h) per view.
        input_dict["image_wh"] = np.ascontiguousarray(
            np.array(input_dict["img_shape"], dtype=np.float32)[:, :2][:, ::-1]
        )
        input_dict["T_global_inv"] = np.linalg.inv(input_dict["lidar2global"])
        input_dict["T_global"] = input_dict["lidar2global"]
        if "cam_intrinsic" in input_dict:
            input_dict["cam_intrinsic"] = np.float32(
                np.stack(input_dict["cam_intrinsic"])
            )
            input_dict["focal"] = input_dict["cam_intrinsic"][..., 0, 0]
        if "instance_inds" in input_dict:
            input_dict["instance_id"] = input_dict["instance_inds"]

        if "gt_bboxes_3d" in input_dict:
            # Wrap yaw into [-pi, pi).
            input_dict["gt_bboxes_3d"][:, 6] = self.limit_period(
                input_dict["gt_bboxes_3d"][:, 6], offset=0.5, period=2 * np.pi
            )
            input_dict["gt_bboxes_3d"] = DC(
                to_tensor(input_dict["gt_bboxes_3d"]).float()
            )
        if "gt_labels_3d" in input_dict:
            input_dict["gt_labels_3d"] = DC(
                to_tensor(input_dict["gt_labels_3d"]).long()
            )
        if "gt_visibility" in input_dict:
            input_dict["gt_visibility"] = DC(
                to_tensor(input_dict["gt_visibility"]).float()
            )

        # HWC -> CHW, then stack views into (num_views, C, H, W).
        imgs = [img.transpose(2, 0, 1) for img in input_dict["img"]]
        imgs = np.ascontiguousarray(np.stack(imgs, axis=0))
        input_dict["img"] = DC(to_tensor(imgs), stack=True)

        # Variable-length per-instance targets: collated without stacking.
        for key in [
            'gt_map_labels',
            'gt_map_pts',
            'gt_agent_fut_trajs',
            'gt_agent_fut_masks',
        ]:
            if key not in input_dict:
                continue
            input_dict[key] = DC(to_tensor(input_dict[key]), stack=False, cpu_only=False)

        # Fixed-shape per-sample targets: collated by stacking.
        for key in [
            'gt_ego_fut_trajs',
            'gt_ego_fut_masks',
            'gt_ego_fut_cmd',
            'ego_status',
        ]:
            if key not in input_dict:
                continue
            input_dict[key] = DC(to_tensor(input_dict[key]), stack=True, cpu_only=False, pad_dims=None)

        return input_dict

    def limit_period(
        self, val: np.ndarray, offset: float = 0.5, period: float = np.pi
    ) -> np.ndarray:
        """Wrap ``val`` into [-offset * period, (1 - offset) * period)."""
        limited_val = val - np.floor(val / period + offset) * period
        return limited_val


@PIPELINES.register_module()
class InstanceNameFilter(object):
    """Filter GT objects by their class labels.

    Args:
        classes (list[str]): List of class names to be kept for training.
    """

    def __init__(self, classes):
        self.classes = classes
        self.labels = list(range(len(self.classes)))

    def __call__(self, input_dict):
        """Drop GT instances whose label is outside the configured classes.

        Args:
            input_dict (dict): Result dict from loading pipeline.

        Returns:
            dict: Results after filtering; 'gt_bboxes_3d', 'gt_labels_3d'
            and the optional aligned per-instance keys are updated.
        """
        gt_labels_3d = input_dict["gt_labels_3d"]
        gt_bboxes_mask = np.array(
            [n in self.labels for n in gt_labels_3d], dtype=np.bool_
        )
        input_dict["gt_bboxes_3d"] = input_dict["gt_bboxes_3d"][gt_bboxes_mask]
        input_dict["gt_labels_3d"] = input_dict["gt_labels_3d"][gt_bboxes_mask]
        # Keep the optional per-instance arrays aligned with the boxes.
        if "instance_inds" in input_dict:
            input_dict["instance_inds"] = input_dict["instance_inds"][gt_bboxes_mask]
        if "gt_agent_fut_trajs" in input_dict:
            input_dict["gt_agent_fut_trajs"] = input_dict["gt_agent_fut_trajs"][gt_bboxes_mask]
            input_dict["gt_agent_fut_masks"] = input_dict["gt_agent_fut_masks"][gt_bboxes_mask]
        if "gt_visibility" in input_dict:
            input_dict["gt_visibility"] = input_dict["gt_visibility"][gt_bboxes_mask]
        return input_dict

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f"(classes={self.classes})"
        return repr_str


@PIPELINES.register_module()
class CircleObjectRangeFilter(object):
    """Filter GT objects by per-class radial distance from the ego origin.

    Args:
        class_dist_thred (list[float]): Max BEV distance per class label;
            index in the list corresponds to the label id.
    """

    def __init__(
        self, class_dist_thred=[52.5] * 5 + [31.5] + [42] * 3 + [31.5]
    ):
        self.class_dist_thred = class_dist_thred

    def __call__(self, input_dict):
        gt_bboxes_3d = input_dict["gt_bboxes_3d"]
        gt_labels_3d = input_dict["gt_labels_3d"]
        # BEV (x, y) distance of each box center.
        dist = np.sqrt(
            np.sum(gt_bboxes_3d[:, :2] ** 2, axis=-1)
        )
        mask = np.array([False] * len(dist))
        for label_idx, dist_thred in enumerate(self.class_dist_thred):
            mask = np.logical_or(
                mask,
                np.logical_and(gt_labels_3d == label_idx, dist <= dist_thred),
            )

        gt_bboxes_3d = gt_bboxes_3d[mask]
        gt_labels_3d = gt_labels_3d[mask]

        input_dict["gt_bboxes_3d"] = gt_bboxes_3d
        input_dict["gt_labels_3d"] = gt_labels_3d
        # Keep the optional per-instance arrays aligned with the boxes.
        if "instance_inds" in input_dict:
            input_dict["instance_inds"] = input_dict["instance_inds"][mask]
        if "gt_agent_fut_trajs" in input_dict:
            input_dict["gt_agent_fut_trajs"] = input_dict["gt_agent_fut_trajs"][mask]
            input_dict["gt_agent_fut_masks"] = input_dict["gt_agent_fut_masks"][mask]
        if "gt_visibility" in input_dict:
            input_dict["gt_visibility"] = input_dict["gt_visibility"][mask]
        return input_dict

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f"(class_dist_thred={self.class_dist_thred})"
        return repr_str


@PIPELINES.register_module()
class NormalizeMultiviewImage(object):
    """Normalize each view image. Added key is "img_norm_cfg".

    Args:
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB,
            default is true.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Normalize every image in ``results['img']``.

        Args:
            results (dict): Result dict from loading pipeline.

        Returns:
            dict: Normalized results; 'img_norm_cfg' key is added.
        """
        results["img"] = [
            mmcv.imnormalize(img, self.mean, self.std, self.to_rgb)
            for img in results["img"]
        ]
        results["img_norm_cfg"] = dict(
            mean=self.mean, std=self.std, to_rgb=self.to_rgb
        )
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f"(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})"
        return repr_str


# ---------------------------------------------------------------------------
# projects/mmdet3d_plugin/datasets/pipelines/vectorize.py
# ---------------------------------------------------------------------------
from typing import List, Tuple, Union, Dict

import numpy as np
from shapely.geometry import LineString
from numpy.typing import NDArray

from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES


@PIPELINES.register_module(force=True)
class VectorizeMap(object):
    """Generate vectorized map and put it into the `semantic_mask` key.

    Concretely, shapely geometry objects are converted into sample points
    (ndarray). Args `sample_num`, `sample_dist` and `simplify` select the
    sampling method.

    Args:
        roi_size (tuple or list): bev range.
        normalize (bool): whether to normalize points to range (0, 1).
        coords_dim (int): dimension of point coordinates.
        simplify (bool): whether to use the simplify function. If true,
            `sample_num` and `sample_dist` will be ignored.
        sample_num (int): number of points to interpolate from a polyline.
            Set to -1 to ignore.
        sample_dist (float): interpolate distance. Set to -1 to ignore.
        permute (bool): whether to expand each line into all equivalent
            start-point / direction permutations.
    """

    def __init__(self,
                 roi_size: Union[Tuple, List],
                 normalize: bool,
                 coords_dim: int=2,
                 simplify: bool=False,
                 sample_num: int=-1,
                 sample_dist: float=-1,
                 permute: bool=False
        ):
        self.coords_dim = coords_dim
        self.sample_num = sample_num
        self.sample_dist = sample_dist
        self.roi_size = np.array(roi_size)
        self.normalize = normalize
        self.simplify = simplify
        self.permute = permute

        # Exactly one sampling strategy must be active.
        if sample_dist > 0:
            assert sample_num < 0 and not simplify
            self.sample_fn = self.interp_fixed_dist
        elif sample_num > 0:
            assert sample_dist < 0 and not simplify
            self.sample_fn = self.interp_fixed_num
        else:
            assert simplify

    def interp_fixed_num(self, line: LineString) -> NDArray:
        '''Interpolate a line to a fixed number of points.

        Args:
            line (LineString): line

        Returns:
            points (array): interpolated points, shape (N, 2)
        '''

        distances = np.linspace(0, line.length, self.sample_num)
        sampled_points = np.array([list(line.interpolate(distance).coords)
            for distance in distances]).squeeze()

        return sampled_points

    def interp_fixed_dist(self, line: LineString) -> NDArray:
        '''Interpolate a line at a fixed interval.

        Args:
            line (LineString): line

        Returns:
            points (array): interpolated points, shape (N, 2)
        '''

        distances = list(np.arange(self.sample_dist, line.length, self.sample_dist))
        # make sure to sample at least two points when sample_dist > line.length
        distances = [0,] + distances + [line.length,]

        sampled_points = np.array([list(line.interpolate(distance).coords)
                                for distance in distances]).squeeze()

        return sampled_points

    def get_vectorized_lines(self, map_geoms: Dict) -> Dict:
        '''Vectorize map elements. Iterate over the input dict and apply
        the configured sample function.

        Args:
            map_geoms (Dict): label -> list of shapely geometries.

        Returns:
            vectors (Dict): dict of vectorized map elements.
        '''

        vectors = {}
        for label, geom_list in map_geoms.items():
            vectors[label] = []
            for geom in geom_list:
                if geom.geom_type == 'LineString':
                    if self.simplify:
                        line = geom.simplify(0.2, preserve_topology=True)
                        line = np.array(line.coords)
                    else:
                        line = self.sample_fn(geom)
                    line = line[:, :self.coords_dim]

                    if self.normalize:
                        line = self.normalize_line(line)
                    if self.permute:
                        line = self.permute_line(line)
                    vectors[label].append(line)

                elif geom.geom_type == 'Polygon':
                    # polygon objects will not be vectorized
                    continue

                else:
                    raise ValueError('map geoms must be either LineString or Polygon!')
        return vectors

    def normalize_line(self, line: NDArray) -> NDArray:
        '''Convert points to range (0, 1).

        Args:
            line (array): points of one line.

        Returns:
            normalized (array): normalized points.
        '''

        origin = -np.array([self.roi_size[0]/2, self.roi_size[1]/2])

        line[:, :2] = line[:, :2] - origin

        # transform from range [0, 1] to (0, 1)
        eps = 1e-5
        line[:, :2] = line[:, :2] / (self.roi_size + eps)

        return line

    def permute_line(self, line: np.ndarray, padding=1e5):
        '''Expand one polyline into its equivalent representations.

        (num_pts, 2) -> (num_permute, num_pts, 2)
        where num_permute = 2 * (num_pts - 1): every rotation of a closed
        line in both directions; open lines only have the two directions,
        the rest is padded.
        '''
        is_closed = np.allclose(line[0], line[-1], atol=1e-3)
        num_points = len(line)
        permute_num = num_points - 1
        permute_lines_list = []
        if is_closed:
            pts_to_permute = line[:-1, :]  # throw away replicate start end pts
            for shift_i in range(permute_num):
                permute_lines_list.append(np.roll(pts_to_permute, shift_i, axis=0))
            flip_pts_to_permute = np.flip(pts_to_permute, axis=0)
            for shift_i in range(permute_num):
                permute_lines_list.append(np.roll(flip_pts_to_permute, shift_i, axis=0))
        else:
            permute_lines_list.append(line)
            permute_lines_list.append(np.flip(line, axis=0))

        permute_lines_array = np.stack(permute_lines_list, axis=0)

        if is_closed:
            tmp = np.zeros((permute_num * 2, num_points, self.coords_dim))
            tmp[:, :-1, :] = permute_lines_array
            tmp[:, -1, :] = permute_lines_array[:, 0, :]  # add replicate start end pts
            permute_lines_array = tmp

        else:
            # pad so open and closed lines share one output shape
            padding = np.full([permute_num * 2 - 2, num_points, self.coords_dim], padding)
            permute_lines_array = np.concatenate((permute_lines_array, padding), axis=0)

        return permute_lines_array

    def __call__(self, input_dict):
        if "map_geoms" not in input_dict:
            return input_dict
        map_geoms = input_dict['map_geoms']
        vectors = self.get_vectorized_lines(map_geoms)

        if self.permute:
            gt_map_labels, gt_map_pts = [], []
            for label, vecs in vectors.items():
                for vec in vecs:
                    gt_map_labels.append(label)
                    gt_map_pts.append(vec)
            input_dict['gt_map_labels'] = np.array(gt_map_labels, dtype=np.int64)
            input_dict['gt_map_pts'] = np.array(gt_map_pts, dtype=np.float32).reshape(-1, 2 * (self.sample_num - 1), self.sample_num, self.coords_dim)
        else:
            input_dict['vectors'] = DC(vectors, stack=False, cpu_only=True)

        return input_dict

    def __repr__(self):
        # Fix: the previous version emitted stray ')' after sample_num and
        # sample_dist and omitted separators before normalize/coords_dim.
        repr_str = self.__class__.__name__
        repr_str += f'(simplify={self.simplify}, '
        repr_str += f'sample_num={self.sample_num}, '
        repr_str += f'sample_dist={self.sample_dist}, '
        repr_str += f'roi_size={self.roi_size}, '
        repr_str += f'normalize={self.normalize}, '
        repr_str += f'coords_dim={self.coords_dim})'

        return repr_str


# ---------------------------------------------------------------------------
# projects/mmdet3d_plugin/datasets/samplers/__init__.py
# ---------------------------------------------------------------------------
from .group_sampler import DistributedGroupSampler
from .distributed_sampler import DistributedSampler
from .sampler import SAMPLER, build_sampler
from .group_in_batch_sampler import (
    GroupInBatchSampler,
)
b/projects/mmdet3d_plugin/datasets/samplers/distributed_sampler.py new file mode 100644 index 0000000..3cd9077 --- /dev/null +++ b/projects/mmdet3d_plugin/datasets/samplers/distributed_sampler.py @@ -0,0 +1,82 @@ +import math + +import torch +from torch.utils.data import DistributedSampler as _DistributedSampler +from .sampler import SAMPLER + +import pdb +import sys + + +class ForkedPdb(pdb.Pdb): + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open("/dev/stdin") + pdb.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + +def set_trace(): + ForkedPdb().set_trace(sys._getframe().f_back) + + +@SAMPLER.register_module() +class DistributedSampler(_DistributedSampler): + def __init__( + self, dataset=None, num_replicas=None, rank=None, shuffle=True, seed=0 + ): + super().__init__( + dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle + ) + # for the compatibility from PyTorch 1.3+ + self.seed = seed if seed is not None else 0 + + def __iter__(self): + # deterministically shuffle based on epoch + assert not self.shuffle + if "data_infos" in dir(self.dataset): + timestamps = [ + x["timestamp"] / 1e6 for x in self.dataset.data_infos + ] + vehicle_idx = [ + x["lidar_path"].split("/")[-1][:4] + if "lidar_path" in x + else None + for x in self.dataset.data_infos + ] + else: + timestamps = [ + x["timestamp"] / 1e6 + for x in self.dataset.datasets[0].data_infos + ] * len(self.dataset.datasets) + vehicle_idx = [ + x["lidar_path"].split("/")[-1][:4] + if "lidar_path" in x + else None + for x in self.dataset.datasets[0].data_infos + ] * len(self.dataset.datasets) + + sequence_splits = [] + for i in range(len(timestamps)): + if i == 0 or ( + abs(timestamps[i] - timestamps[i - 1]) > 4 + or vehicle_idx[i] != vehicle_idx[i - 1] + ): + sequence_splits.append([i]) + else: + sequence_splits[-1].append(i) + + indices = [] + perfix_sum = 0 + split_length = len(self.dataset) // self.num_replicas + for i in 
range(len(sequence_splits)): + if perfix_sum >= (self.rank + 1) * split_length: + break + elif perfix_sum >= self.rank * split_length: + indices.extend(sequence_splits[i]) + perfix_sum += len(sequence_splits[i]) + + self.num_samples = len(indices) + return iter(indices) diff --git a/projects/mmdet3d_plugin/datasets/samplers/group_in_batch_sampler.py b/projects/mmdet3d_plugin/datasets/samplers/group_in_batch_sampler.py new file mode 100644 index 0000000..62b5cb0 --- /dev/null +++ b/projects/mmdet3d_plugin/datasets/samplers/group_in_batch_sampler.py @@ -0,0 +1,178 @@ +# https://github.com/Divadi/SOLOFusion/blob/main/mmdet3d/datasets/samplers/infinite_group_each_sample_in_batch_sampler.py +import itertools +import copy + +import numpy as np +import torch +import torch.distributed as dist +from mmcv.runner import get_dist_info +from torch.utils.data.sampler import Sampler + + +# https://github.com/open-mmlab/mmdetection/blob/3b72b12fe9b14de906d1363982b9fba05e7d47c1/mmdet/core/utils/dist_utils.py#L157 +def sync_random_seed(seed=None, device="cuda"): + """Make sure different ranks share the same seed. + All workers must call this function, otherwise it will deadlock. + This method is generally used in `DistributedSampler`, + because the seed should be identical across all processes + in the distributed group. + In distributed sampling, different ranks should sample non-overlapped + data in the dataset. Therefore, this function is used to make sure that + each rank shuffles the data indices in the same order based + on the same seed. Then different ranks could use different indices + to select non-overlapped data from the same data list. + Args: + seed (int, Optional): The seed. Default to None. + device (str): The device where the seed will be put on. + Default to 'cuda'. + Returns: + int: Seed to be used. 
+ """ + if seed is None: + seed = np.random.randint(2**31) + assert isinstance(seed, int) + + rank, world_size = get_dist_info() + + if world_size == 1: + return seed + + if rank == 0: + random_num = torch.tensor(seed, dtype=torch.int32, device=device) + else: + random_num = torch.tensor(0, dtype=torch.int32, device=device) + dist.broadcast(random_num, src=0) + return random_num.item() + + +class GroupInBatchSampler(Sampler): + """ + Pardon this horrendous name. Basically, we want every sample to be from its own group. + If batch size is 4 and # of GPUs is 8, each sample of these 32 should be operating on + its own group. + + Shuffling is only done for group order, not done within groups. + """ + + def __init__( + self, + dataset, + batch_size=1, + world_size=None, + rank=None, + seed=0, + skip_prob=0., + sequence_flip_prob=0., + ): + _rank, _world_size = get_dist_info() + if world_size is None: + world_size = _world_size + if rank is None: + rank = _rank + + self.dataset = dataset + self.batch_size = batch_size + self.world_size = world_size + self.rank = rank + self.seed = sync_random_seed(seed) + + self.size = len(self.dataset) + + assert hasattr(self.dataset, "flag") + self.flag = self.dataset.flag + self.group_sizes = np.bincount(self.flag) + self.groups_num = len(self.group_sizes) + self.global_batch_size = batch_size * world_size + assert self.groups_num >= self.global_batch_size + + # Now, for efficiency, make a dict group_idx: List[dataset sample_idxs] + self.group_idx_to_sample_idxs = { + group_idx: np.where(self.flag == group_idx)[0].tolist() + for group_idx in range(self.groups_num) + } + + # Get a generator per sample idx. 
Considering samples over all + # GPUs, each sample position has its own generator + self.group_indices_per_global_sample_idx = [ + self._group_indices_per_global_sample_idx( + self.rank * self.batch_size + local_sample_idx + ) + for local_sample_idx in range(self.batch_size) + ] + + # Keep track of a buffer of dataset sample idxs for each local sample idx + self.buffer_per_local_sample = [[] for _ in range(self.batch_size)] + self.aug_per_local_sample = [None for _ in range(self.batch_size)] + self.skip_prob = skip_prob + self.sequence_flip_prob = sequence_flip_prob + + def _infinite_group_indices(self): + g = torch.Generator() + g.manual_seed(self.seed) + while True: + yield from torch.randperm(self.groups_num, generator=g).tolist() + + def _group_indices_per_global_sample_idx(self, global_sample_idx): + yield from itertools.islice( + self._infinite_group_indices(), + global_sample_idx, + None, + self.global_batch_size, + ) + + def __iter__(self): + while True: + curr_batch = [] + for local_sample_idx in range(self.batch_size): + skip = ( + np.random.uniform() < self.skip_prob + and len(self.buffer_per_local_sample[local_sample_idx]) > 1 + ) + if len(self.buffer_per_local_sample[local_sample_idx]) == 0: + # Finished current group, refill with next group + # skip = False + new_group_idx = next( + self.group_indices_per_global_sample_idx[ + local_sample_idx + ] + ) + self.buffer_per_local_sample[ + local_sample_idx + ] = copy.deepcopy( + self.group_idx_to_sample_idxs[new_group_idx] + ) + if np.random.uniform() < self.sequence_flip_prob: + self.buffer_per_local_sample[ + local_sample_idx + ] = self.buffer_per_local_sample[local_sample_idx][ + ::-1 + ] + if self.dataset.keep_consistent_seq_aug: + self.aug_per_local_sample[ + local_sample_idx + ] = self.dataset.get_augmentation() + + if not self.dataset.keep_consistent_seq_aug: + self.aug_per_local_sample[ + local_sample_idx + ] = self.dataset.get_augmentation() + + if skip: + 
self.buffer_per_local_sample[local_sample_idx].pop(0) + curr_batch.append( + dict( + idx=self.buffer_per_local_sample[local_sample_idx].pop( + 0 + ), + aug_config=self.aug_per_local_sample[local_sample_idx], + ) + ) + + yield curr_batch + + def __len__(self): + """Length of base dataset.""" + return self.size + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/projects/mmdet3d_plugin/datasets/samplers/group_sampler.py b/projects/mmdet3d_plugin/datasets/samplers/group_sampler.py new file mode 100644 index 0000000..a14e114 --- /dev/null +++ b/projects/mmdet3d_plugin/datasets/samplers/group_sampler.py @@ -0,0 +1,119 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math + +import numpy as np +import torch +from mmcv.runner import get_dist_info +from torch.utils.data import Sampler +from .sampler import SAMPLER +import random +from IPython import embed + + +@SAMPLER.register_module() +class DistributedGroupSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + .. note:: + Dataset is assumed to be of constant size. + Arguments: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + seed (int, optional): random seed used to shuffle the sampler if + ``shuffle=True``. This number should be identical across all + processes in the distributed group. Default: 0. 
+ """ + + def __init__( + self, dataset, samples_per_gpu=1, num_replicas=None, rank=None, seed=0 + ): + _rank, _num_replicas = get_dist_info() + if num_replicas is None: + num_replicas = _num_replicas + if rank is None: + rank = _rank + self.dataset = dataset + self.samples_per_gpu = samples_per_gpu + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.seed = seed if seed is not None else 0 + + assert hasattr(self.dataset, "flag") + self.flag = self.dataset.flag + self.group_sizes = np.bincount(self.flag) + + self.num_samples = 0 + for i, j in enumerate(self.group_sizes): + self.num_samples += ( + int( + math.ceil( + self.group_sizes[i] + * 1.0 + / self.samples_per_gpu + / self.num_replicas + ) + ) + * self.samples_per_gpu + ) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch + self.seed) + + indices = [] + for i, size in enumerate(self.group_sizes): + if size > 0: + indice = np.where(self.flag == i)[0] + assert len(indice) == size + # add .numpy() to avoid bug when selecting indice in parrots. + # TODO: check whether torch.randperm() can be replaced by + # numpy.random.permutation(). 
+ indice = indice[ + list(torch.randperm(int(size), generator=g).numpy()) + ].tolist() + extra = int( + math.ceil( + size * 1.0 / self.samples_per_gpu / self.num_replicas + ) + ) * self.samples_per_gpu * self.num_replicas - len(indice) + # pad indice + tmp = indice.copy() + for _ in range(extra // size): + indice.extend(tmp) + indice.extend(tmp[: extra % size]) + indices.extend(indice) + + assert len(indices) == self.total_size + + indices = [ + indices[j] + for i in list( + torch.randperm( + len(indices) // self.samples_per_gpu, generator=g + ) + ) + for j in range( + i * self.samples_per_gpu, (i + 1) * self.samples_per_gpu + ) + ] + + # subsample + offset = self.num_samples * self.rank + indices = indices[offset : offset + self.num_samples] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/projects/mmdet3d_plugin/datasets/samplers/sampler.py b/projects/mmdet3d_plugin/datasets/samplers/sampler.py new file mode 100644 index 0000000..9bfc443 --- /dev/null +++ b/projects/mmdet3d_plugin/datasets/samplers/sampler.py @@ -0,0 +1,7 @@ +from mmcv.utils.registry import Registry, build_from_cfg + +SAMPLER = Registry("sampler") + + +def build_sampler(cfg, default_args): + return build_from_cfg(cfg, SAMPLER, default_args) diff --git a/projects/mmdet3d_plugin/datasets/utils.py b/projects/mmdet3d_plugin/datasets/utils.py new file mode 100644 index 0000000..bf24b2f --- /dev/null +++ b/projects/mmdet3d_plugin/datasets/utils.py @@ -0,0 +1,225 @@ +import copy + +import cv2 +import numpy as np +import torch + +from projects.mmdet3d_plugin.core.box3d import * + + +def box3d_to_corners(box3d): + if isinstance(box3d, torch.Tensor): + box3d = box3d.detach().cpu().numpy() + corners_norm = np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1) + corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]] + # use relative origin [0.5, 0.5, 0] + corners_norm = 
corners_norm - np.array([0.5, 0.5, 0.5]) + corners = box3d[:, None, [W, L, H]] * corners_norm.reshape([1, 8, 3]) + + # rotate around z axis + rot_cos = np.cos(box3d[:, YAW]) + rot_sin = np.sin(box3d[:, YAW]) + rot_mat = np.tile(np.eye(3)[None], (box3d.shape[0], 1, 1)) + rot_mat[:, 0, 0] = rot_cos + rot_mat[:, 0, 1] = -rot_sin + rot_mat[:, 1, 0] = rot_sin + rot_mat[:, 1, 1] = rot_cos + corners = (rot_mat[:, None] @ corners[..., None]).squeeze(axis=-1) + corners += box3d[:, None, :3] + return corners + + +def plot_rect3d_on_img( + img, num_rects, rect_corners, color=(0, 255, 0), thickness=1 +): + """Plot the boundary lines of 3D rectangular on 2D images. + + Args: + img (numpy.array): The numpy array of image. + num_rects (int): Number of 3D rectangulars. + rect_corners (numpy.array): Coordinates of the corners of 3D + rectangulars. Should be in the shape of [num_rect, 8, 2]. + color (tuple[int], optional): The color to draw bboxes. + Default: (0, 255, 0). + thickness (int, optional): The thickness of bboxes. Default: 1. 
+ """ + line_indices = ( + (0, 1), + (0, 3), + (0, 4), + (1, 2), + (1, 5), + (3, 2), + (3, 7), + (4, 5), + (4, 7), + (2, 6), + (5, 6), + (6, 7), + ) + h, w = img.shape[:2] + for i in range(num_rects): + corners = np.clip(rect_corners[i], -1e4, 1e5).astype(np.int32) + for start, end in line_indices: + if ( + (corners[start, 1] >= h or corners[start, 1] < 0) + or (corners[start, 0] >= w or corners[start, 0] < 0) + ) and ( + (corners[end, 1] >= h or corners[end, 1] < 0) + or (corners[end, 0] >= w or corners[end, 0] < 0) + ): + continue + if isinstance(color[0], int): + cv2.line( + img, + (corners[start, 0], corners[start, 1]), + (corners[end, 0], corners[end, 1]), + color, + thickness, + cv2.LINE_AA, + ) + else: + cv2.line( + img, + (corners[start, 0], corners[start, 1]), + (corners[end, 0], corners[end, 1]), + color[i], + thickness, + cv2.LINE_AA, + ) + + return img.astype(np.uint8) + + +def draw_lidar_bbox3d_on_img( + bboxes3d, raw_img, lidar2img_rt, img_metas=None, color=(0, 255, 0), thickness=1 +): + """Project the 3D bbox on 2D plane and draw on input image. + + Args: + bboxes3d (:obj:`LiDARInstance3DBoxes`): + 3d bbox in lidar coordinate system to visualize. + raw_img (numpy.array): The numpy array of image. + lidar2img_rt (numpy.array, shape=[4, 4]): The projection matrix + according to the camera intrinsic parameters. + img_metas (dict): Useless here. + color (tuple[int], optional): The color to draw bboxes. + Default: (0, 255, 0). + thickness (int, optional): The thickness of bboxes. Default: 1. 
+ """ + img = raw_img.copy() + # corners_3d = bboxes3d.corners + corners_3d = box3d_to_corners(bboxes3d) + num_bbox = corners_3d.shape[0] + pts_4d = np.concatenate( + [corners_3d.reshape(-1, 3), np.ones((num_bbox * 8, 1))], axis=-1 + ) + lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4) + if isinstance(lidar2img_rt, torch.Tensor): + lidar2img_rt = lidar2img_rt.cpu().numpy() + pts_2d = pts_4d @ lidar2img_rt.T + + pts_2d[:, 2] = np.clip(pts_2d[:, 2], a_min=1e-5, a_max=1e5) + pts_2d[:, 0] /= pts_2d[:, 2] + pts_2d[:, 1] /= pts_2d[:, 2] + imgfov_pts_2d = pts_2d[..., :2].reshape(num_bbox, 8, 2) + + return plot_rect3d_on_img(img, num_bbox, imgfov_pts_2d, color, thickness) + + +def draw_points_on_img(points, img, lidar2img_rt, color=(0, 255, 0), circle=4): + img = img.copy() + N = points.shape[0] + points = points.cpu().numpy() + lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4) + if isinstance(lidar2img_rt, torch.Tensor): + lidar2img_rt = lidar2img_rt.cpu().numpy() + pts_2d = ( + np.sum(points[:, :, None] * lidar2img_rt[:3, :3], axis=-1) + + lidar2img_rt[:3, 3] + ) + pts_2d[..., 2] = np.clip(pts_2d[..., 2], a_min=1e-5, a_max=1e5) + pts_2d = pts_2d[..., :2] / pts_2d[..., 2:3] + pts_2d = np.clip(pts_2d, -1e4, 1e4).astype(np.int32) + + for i in range(N): + for point in pts_2d[i]: + if isinstance(color[0], int): + color_tmp = color + else: + color_tmp = color[i] + cv2.circle(img, point.tolist(), circle, color_tmp, thickness=-1) + return img.astype(np.uint8) + + +def draw_lidar_bbox3d_on_bev( + bboxes_3d, bev_size, bev_range=115, color=(255, 0, 0), thickness=3): + if isinstance(bev_size, (list, tuple)): + bev_h, bev_w = bev_size + else: + bev_h, bev_w = bev_size, bev_size + bev = np.zeros([bev_h, bev_w, 3]) + + marking_color = (127, 127, 127) + bev_resolution = bev_range / bev_h + for cir in range(int(bev_range / 2 / 10)): + cv2.circle( + bev, + (int(bev_h / 2), int(bev_w / 2)), + int((cir + 1) * 10 / bev_resolution), + marking_color, + thickness=thickness, + ) + 
cv2.line( + bev, + (0, int(bev_h / 2)), + (bev_w, int(bev_h / 2)), + marking_color, + ) + cv2.line( + bev, + (int(bev_w / 2), 0), + (int(bev_w / 2), bev_h), + marking_color, + ) + if len(bboxes_3d) != 0: + bev_corners = box3d_to_corners(bboxes_3d)[:, [0, 3, 4, 7]][ + ..., [0, 1] + ] + xs = bev_corners[..., 0] / bev_resolution + bev_w / 2 + ys = -bev_corners[..., 1] / bev_resolution + bev_h / 2 + for obj_idx, (x, y) in enumerate(zip(xs, ys)): + for p1, p2 in ((0, 1), (0, 2), (1, 3), (2, 3)): + if isinstance(color[0], (list, tuple)): + tmp = color[obj_idx] + else: + tmp = color + cv2.line( + bev, + (int(x[p1]), int(y[p1])), + (int(x[p2]), int(y[p2])), + tmp, + thickness=thickness, + ) + return bev.astype(np.uint8) + + +def draw_lidar_bbox3d(bboxes_3d, imgs, lidar2imgs, color=(255, 0, 0)): + vis_imgs = [] + for i, (img, lidar2img) in enumerate(zip(imgs, lidar2imgs)): + vis_imgs.append( + draw_lidar_bbox3d_on_img(bboxes_3d, img, lidar2img, color=color) + ) + + num_imgs = len(vis_imgs) + if num_imgs < 4 or num_imgs % 2 != 0: + vis_imgs = np.concatenate(vis_imgs, axis=1) + else: + vis_imgs = np.concatenate([ + np.concatenate(vis_imgs[:num_imgs//2], axis=1), + np.concatenate(vis_imgs[num_imgs//2:], axis=1) + ], axis=0) + + bev = draw_lidar_bbox3d_on_bev(bboxes_3d, vis_imgs.shape[0], color=color) + vis_imgs = np.concatenate([bev, vis_imgs], axis=1) + return vis_imgs diff --git a/projects/mmdet3d_plugin/models/__init__.py b/projects/mmdet3d_plugin/models/__init__.py new file mode 100644 index 0000000..eb86c67 --- /dev/null +++ b/projects/mmdet3d_plugin/models/__init__.py @@ -0,0 +1,34 @@ +from .sparsedrive import SparseDrive +from .sparsedrive_head import SparseDriveHead +from .gt_sparse_drive_head import GTSparseDriveHead +from .blocks import ( + DeformableFeatureAggregation, + DenseDepthNet, + AsymmetricFFN, +) +from .instance_bank import InstanceBank +from .detection3d import ( + SparseBox3DDecoder, + SparseBox3DTarget, + SparseBox3DRefinementModule, + 
SparseBox3DKeyPointsGenerator, + SparseBox3DEncoder, +) +from .map import * +from .motion import * + + +__all__ = [ + "SparseDrive", + "SparseDriveHead", + "GTSparseDriveHead", + "DeformableFeatureAggregation", + "DenseDepthNet", + "AsymmetricFFN", + "InstanceBank", + "SparseBox3DDecoder", + "SparseBox3DTarget", + "SparseBox3DRefinementModule", + "SparseBox3DKeyPointsGenerator", + "SparseBox3DEncoder", +] diff --git a/projects/mmdet3d_plugin/models/attention.py b/projects/mmdet3d_plugin/models/attention.py new file mode 100644 index 0000000..9afe1ec --- /dev/null +++ b/projects/mmdet3d_plugin/models/attention.py @@ -0,0 +1,319 @@ +import warnings +import math + +import torch +import torch.nn as nn +from torch.nn.functional import linear +from torch.nn.init import xavier_uniform_, constant_ + +from mmcv.utils import deprecated_api_warning +from mmcv.runner import auto_fp16 +from mmcv.runner.base_module import BaseModule +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn.bricks.registry import ATTENTION +import torch.utils.checkpoint as cp + + +from einops import rearrange +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func + print('Use flash_attn_unpadded_kvpacked_func') +except: + from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func + print('Use flash_attn_varlen_kvpacked_func') +from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis + + +def _in_projection_packed(q, k, v, w, b = None): + w_q, w_k, w_v = w.chunk(3) + if b is None: + b_q = b_k = b_v = None + else: + b_q, b_k, b_v = b.chunk(3) + return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v) + + +class FlashAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. 
+ (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.1) + """ + def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None): + super().__init__() + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + self.fp16_enabled = True + + @auto_fp16(apply_to=('q', 'kv'), out_fp32=True) + def forward(self, q, kv, + causal=False, + key_padding_mask=None): + """Implements the multihead softmax attention. + Arguments + --------- + q: The tensor containing the query. (B, T, H, D) + kv: The tensor containing the key, and value. (B, S, 2, H, D) + key_padding_mask: a bool tensor of shape (B, S) + """ + assert q.dtype in [torch.float16, torch.bfloat16] and kv.dtype in [torch.float16, torch.bfloat16] + assert q.is_cuda and kv.is_cuda + assert q.shape[0] == kv.shape[0] and q.shape[-2] == kv.shape[-2] and q.shape[-1] == kv.shape[-1] + + batch_size = q.shape[0] + seqlen_q, seqlen_k = q.shape[1], kv.shape[1] + if key_padding_mask is None: + q, kv = rearrange(q, 'b s ... -> (b s) ...'), rearrange(kv, 'b s ... -> (b s) ...') + max_sq, max_sk = seqlen_q, seqlen_k + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, + device=q.device) + cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, + device=kv.device) + output = flash_attn_unpadded_kvpacked_func( + q, kv, cu_seqlens_q, cu_seqlens_k, max_sq, max_sk, + self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(output, '(b s) ... -> b s ...', b=batch_size) + else: + nheads = kv.shape[-2] + q = rearrange(q, 'b s ... 
-> (b s) ...') + max_sq = seqlen_q + cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, + device=q.device) + x = rearrange(kv, 'b s two h d -> b s (two h d)') + x_unpad, indices, cu_seqlens_k, max_sk = unpad_input(x, key_padding_mask) + x_unpad = rearrange(x_unpad, 'nnz (two h d) -> nnz two h d', two=2, h=nheads) + output_unpad = flash_attn_unpadded_kvpacked_func( + q, x_unpad, cu_seqlens_q, cu_seqlens_k, max_sq, max_sk, + self.dropout_p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + output = rearrange(output_unpad, '(b s) ... -> b s ...', b=batch_size) + + return output, None + + +class FlashMHA(nn.Module): + + def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0, + causal=False, device=None, dtype=None, **kwargs) -> None: + assert batch_first + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.embed_dim = embed_dim + self.causal = causal + self.bias = bias + + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + + self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim))) + if bias: + self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self._reset_parameters() + + def _reset_parameters(self) -> None: + xavier_uniform_(self.in_proj_weight) + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.) + constant_(self.out_proj.bias, 0.) 
+ + def forward(self, q, k, v, key_padding_mask=None): + """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) + key_padding_mask: bool tensor of shape (batch, seqlen) + """ + q, k, v = _in_projection_packed(q, k, v, self.in_proj_weight, self.in_proj_bias) + q = rearrange(q, 'b s (h d) -> b s h d', h=self.num_heads) + k = rearrange(k, 'b s (h d) -> b s h d', h=self.num_heads) + v = rearrange(v, 'b s (h d) -> b s h d', h=self.num_heads) + kv = torch.stack([k, v], dim=2) + + context, attn_weights = self.inner_attn(q, kv, key_padding_mask=key_padding_mask, causal=self.causal) + return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights + + +@ATTENTION.register_module() +class MultiheadFlashAttention(BaseModule): + """A wrapper for ``torch.nn.MultiheadAttention``. + This module implements MultiheadAttention with identity connection, + and positional encoding is also passed as input. + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (agent:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (agent:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + batch_first (bool): When it is True, Key, Query and Value are shape of + (batch, n, embed_dim), otherwise (n, batch, embed_dim). + Default to False. 
+ """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=dict(type='Dropout', drop_prob=0.), + init_cfg=None, + batch_first=True, + **kwargs): + super(MultiheadFlashAttention, self).__init__(init_cfg) + if 'dropout' in kwargs: + warnings.warn( + 'The arguments `dropout` in MultiheadAttention ' + 'has been deprecated, now you can separately ' + 'set `attn_drop`(float), proj_drop(float), ' + 'and `dropout_layer`(dict) ', DeprecationWarning) + attn_drop = kwargs['dropout'] + dropout_layer['drop_prob'] = kwargs.pop('dropout') + + self.embed_dims = embed_dims + self.num_heads = num_heads + self.batch_first = True + self.attn = FlashMHA( + embed_dim=embed_dims, + num_heads=num_heads, + attention_dropout=attn_drop, + dtype=torch.float16, + device='cuda', + **kwargs + ) + + self.proj_drop = nn.Dropout(proj_drop) + self.dropout_layer = build_dropout( + dropout_layer) if dropout_layer else nn.Identity() + + @deprecated_api_warning({'residual': 'identity'}, + cls_name='MultiheadAttention') + def forward(self, + query, + key=None, + value=None, + identity=None, + query_pos=None, + key_pos=None, + attn_mask=None, + key_padding_mask=None, + **kwargs): + """Forward function for `MultiheadAttention`. + **kwargs allow passing a more general data flow when combining + with other operations in `transformerlayer`. + Args: + query (Tensor): The input query with shape [num_queries, bs, + embed_dims] if self.batch_first is False, else + [bs, num_queries embed_dims]. + key (Tensor): The key tensor with shape [num_keys, bs, + embed_dims] if self.batch_first is False, else + [bs, num_keys, embed_dims] . + If None, the ``query`` will be used. Defaults to None. + value (Tensor): The value tensor with same shape as `key`. + Same in `nn.MultiheadAttention.forward`. Defaults to None. + If None, the `key` will be used. + identity (Tensor): This tensor, with the same shape as x, + will be used for the identity link. + If None, `x` will be used. 
Defaults to None. + query_pos (Tensor): The positional encoding for query, with + the same shape as `x`. If not None, it will + be added to `x` before forward function. Defaults to None. + key_pos (Tensor): The positional encoding for `key`, with the + same shape as `key`. Defaults to None. If not None, it will + be added to `key` before forward function. If None, and + `query_pos` has the same shape as `key`, then `query_pos` + will be used for `key_pos`. Defaults to None. + attn_mask (Tensor): ByteTensor mask with shape [num_queries, + num_keys]. Same in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys]. + Defaults to None. + Returns: + Tensor: forwarded results with shape + [num_queries, bs, embed_dims] + if self.batch_first is False, else + [bs, num_queries embed_dims]. + """ + if key is None: + key = query + if value is None: + value = key + if identity is None: + identity = query + if key_pos is None: + if query_pos is not None: + # use query_pos if key_pos is not available + if query_pos.shape == key.shape: + key_pos = query_pos + else: + warnings.warn(f'position encoding of key is' + f'missing in {self.__class__.__name__}.') + if query_pos is not None: + query = query + query_pos + if key_pos is not None: + key = key + key_pos + + if attn_mask is not None: + # FlashAttention does not support arbitrary attn_mask (e.g. DN training masks). + # Fall back to standard attention using the same projection weights. 
+ import torch.nn.functional as F + out, _ = F.multi_head_attention_forward( + query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1), + self.attn.embed_dim, self.attn.num_heads, + self.attn.in_proj_weight, self.attn.in_proj_bias, + None, None, False, + self.attn.inner_attn.dropout_p if self.training else 0.0, + self.attn.out_proj.weight, self.attn.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + attn_mask=attn_mask, + ) + return identity + self.dropout_layer(self.proj_drop(out.transpose(0, 1))) + + # The dataflow('key', 'query', 'value') of ``FlashAttention`` is (batch, num_query, embed_dims). + if not self.batch_first: + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + out = self.attn( + q=query, + k=key, + v=value, + key_padding_mask=key_padding_mask)[0] + + if not self.batch_first: + out = out.transpose(0, 1) + + return identity + self.dropout_layer(self.proj_drop(out)) + + +def gen_sineembed_for_position(pos_tensor, hidden_dim=256): + """Mostly copy-paste from https://github.com/IDEA-opensource/DAB-DETR/ + """ + half_hidden_dim = hidden_dim // 2 + scale = 2 * math.pi + dim_t = torch.arange(half_hidden_dim, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * (dim_t // 2) / half_hidden_dim) + x_embed = pos_tensor[..., 0] * scale + y_embed = pos_tensor[..., 1] * scale + pos_x = x_embed[..., None] / dim_t + pos_y = y_embed[..., None] / dim_t + pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2) + pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2) + pos = torch.cat((pos_y, pos_x), dim=-1) + return pos + diff --git a/projects/mmdet3d_plugin/models/base_target.py b/projects/mmdet3d_plugin/models/base_target.py new file mode 100644 index 0000000..020a88d --- /dev/null +++ b/projects/mmdet3d_plugin/models/base_target.py @@ -0,0 +1,49 @@ +from abc import ABC, abstractmethod + + +__all__ = 
["BaseTargetWithDenoising"] + + +class BaseTargetWithDenoising(ABC): + def __init__(self, num_dn_groups=0, num_temp_dn_groups=0): + super(BaseTargetWithDenoising, self).__init__() + self.num_dn_groups = num_dn_groups + self.num_temp_dn_groups = num_temp_dn_groups + self.dn_metas = None + + @abstractmethod + def sample(self, cls_pred, box_pred, cls_target, box_target): + """ + Perform Hungarian matching between predictions and ground truth, + returning the matched ground truth corresponding to the predictions + along with the corresponding regression weights. + """ + + def get_dn_anchors(self, cls_target, box_target, *args, **kwargs): + """ + Generate noisy instances for the current frame, with a total of + 'self.num_dn_groups' groups. + """ + return None + + def update_dn(self, instance_feature, anchor, *args, **kwargs): + """ + Insert the previously saved 'self.dn_metas' into the noisy instances + of the current frame. + """ + + def cache_dn( + self, + dn_instance_feature, + dn_anchor, + dn_cls_target, + valid_mask, + dn_id_target, + ): + """ + Randomly save information for 'self.num_temp_dn_groups' groups of + temporal noisy instances to 'self.dn_metas'. 
+ """ + if self.num_temp_dn_groups < 0: + return + self.dn_metas = dict(dn_anchor=dn_anchor[:, : self.num_temp_dn_groups]) diff --git a/projects/mmdet3d_plugin/models/blocks.py b/projects/mmdet3d_plugin/models/blocks.py new file mode 100644 index 0000000..32cacdc --- /dev/null +++ b/projects/mmdet3d_plugin/models/blocks.py @@ -0,0 +1,393 @@ +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp.autocast_mode import autocast + +from mmcv.cnn import Linear, build_activation_layer, build_norm_layer +from mmcv.runner.base_module import Sequential, BaseModule +from mmcv.cnn.bricks.transformer import FFN +from mmcv.utils import build_from_cfg +from mmcv.cnn.bricks.drop import build_dropout +from mmcv.cnn import xavier_init, constant_init +from mmcv.cnn.bricks.registry import ( + ATTENTION, + PLUGIN_LAYERS, + FEEDFORWARD_NETWORK, +) + +try: + from ..ops import deformable_aggregation_function as DAF +except: + DAF = None + +__all__ = [ + "DeformableFeatureAggregation", + "DenseDepthNet", + "AsymmetricFFN", +] + + +def linear_relu_ln(embed_dims, in_loops, out_loops, input_dims=None): + if input_dims is None: + input_dims = embed_dims + layers = [] + for _ in range(out_loops): + for _ in range(in_loops): + layers.append(Linear(input_dims, embed_dims)) + layers.append(nn.ReLU(inplace=True)) + input_dims = embed_dims + layers.append(nn.LayerNorm(embed_dims)) + return layers + + +@ATTENTION.register_module() +class DeformableFeatureAggregation(BaseModule): + def __init__( + self, + embed_dims: int = 256, + num_groups: int = 8, + num_levels: int = 4, + num_cams: int = 6, + proj_drop: float = 0.0, + attn_drop: float = 0.0, + kps_generator: dict = None, + temporal_fusion_module=None, + use_temporal_anchor_embed=True, + use_deformable_func=False, + use_camera_embed=False, + residual_mode="add", + ): + super(DeformableFeatureAggregation, self).__init__() + if embed_dims % num_groups != 0: + raise ValueError( + 
f"embed_dims must be divisible by num_groups, " + f"but got {embed_dims} and {num_groups}" + ) + self.group_dims = int(embed_dims / num_groups) + self.embed_dims = embed_dims + self.num_levels = num_levels + self.num_groups = num_groups + self.num_cams = num_cams + self.use_temporal_anchor_embed = use_temporal_anchor_embed + if use_deformable_func: + assert DAF is not None, "deformable_aggregation needs to be set up." + self.use_deformable_func = use_deformable_func + self.attn_drop = attn_drop + self.residual_mode = residual_mode + self.proj_drop = nn.Dropout(proj_drop) + kps_generator["embed_dims"] = embed_dims + self.kps_generator = build_from_cfg(kps_generator, PLUGIN_LAYERS) + self.num_pts = self.kps_generator.num_pts + if temporal_fusion_module is not None: + if "embed_dims" not in temporal_fusion_module: + temporal_fusion_module["embed_dims"] = embed_dims + self.temp_module = build_from_cfg( + temporal_fusion_module, PLUGIN_LAYERS + ) + else: + self.temp_module = None + self.output_proj = Linear(embed_dims, embed_dims) + + if use_camera_embed: + self.camera_encoder = Sequential( + *linear_relu_ln(embed_dims, 1, 2, 12) + ) + self.weights_fc = Linear( + embed_dims, num_groups * num_levels * self.num_pts + ) + else: + self.camera_encoder = None + self.weights_fc = Linear( + embed_dims, num_groups * num_cams * num_levels * self.num_pts + ) + + def init_weight(self): + constant_init(self.weights_fc, val=0.0, bias=0.0) + xavier_init(self.output_proj, distribution="uniform", bias=0.0) + + def forward( + self, + instance_feature: torch.Tensor, + anchor: torch.Tensor, + anchor_embed: torch.Tensor, + feature_maps: List[torch.Tensor], + metas: dict, + **kwargs: dict, + ): + bs, num_anchor = instance_feature.shape[:2] + key_points = self.kps_generator(anchor, instance_feature) + weights = self._get_weights(instance_feature, anchor_embed, metas) + + if self.use_deformable_func: + points_2d = ( + self.project_points( + key_points, + metas["projection_mat"], + 
metas.get("image_wh"), + ) + .permute(0, 2, 3, 1, 4) + .reshape(bs, num_anchor, self.num_pts, self.num_cams, 2) + ) + weights = ( + weights.permute(0, 1, 4, 2, 3, 5) + .contiguous() + .reshape( + bs, + num_anchor, + self.num_pts, + self.num_cams, + self.num_levels, + self.num_groups, + ) + ) + features = DAF(*feature_maps, points_2d, weights).reshape( + bs, num_anchor, self.embed_dims + ) + else: + features = self.feature_sampling( + feature_maps, + key_points, + metas["projection_mat"], + metas.get("image_wh"), + ) + features = self.multi_view_level_fusion(features, weights) + features = features.sum(dim=2) # fuse multi-point features + output = self.proj_drop(self.output_proj(features)) + if self.residual_mode == "add": + output = output + instance_feature + elif self.residual_mode == "cat": + output = torch.cat([output, instance_feature], dim=-1) + return output + + def _get_weights(self, instance_feature, anchor_embed, metas=None): + bs, num_anchor = instance_feature.shape[:2] + feature = instance_feature + anchor_embed + if self.camera_encoder is not None: + camera_embed = self.camera_encoder( + metas["projection_mat"][:, :, :3].reshape( + bs, self.num_cams, -1 + ) + ) + feature = feature[:, :, None] + camera_embed[:, None] + + weights = ( + self.weights_fc(feature) + .reshape(bs, num_anchor, -1, self.num_groups) + .softmax(dim=-2) + .reshape( + bs, + num_anchor, + self.num_cams, + self.num_levels, + self.num_pts, + self.num_groups, + ) + ) + if self.training and self.attn_drop > 0: + mask = torch.rand( + bs, num_anchor, self.num_cams, 1, self.num_pts, 1 + ) + mask = mask.to(device=weights.device, dtype=weights.dtype) + weights = ((mask > self.attn_drop) * weights) / ( + 1 - self.attn_drop + ) + return weights + + @staticmethod + def project_points(key_points, projection_mat, image_wh=None): + bs, num_anchor, num_pts = key_points.shape[:3] + + pts_extend = torch.cat( + [key_points, torch.ones_like(key_points[..., :1])], dim=-1 + ) + points_2d = torch.matmul( + 
projection_mat[:, :, None, None], pts_extend[:, None, ..., None] + ).squeeze(-1) + points_2d = points_2d[..., :2] / torch.clamp( + points_2d[..., 2:3], min=1e-5 + ) + if image_wh is not None: + points_2d = points_2d / image_wh[:, :, None, None] + return points_2d + + @staticmethod + def feature_sampling( + feature_maps: List[torch.Tensor], + key_points: torch.Tensor, + projection_mat: torch.Tensor, + image_wh: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + num_levels = len(feature_maps) + num_cams = feature_maps[0].shape[1] + bs, num_anchor, num_pts = key_points.shape[:3] + + points_2d = DeformableFeatureAggregation.project_points( + key_points, projection_mat, image_wh + ) + points_2d = points_2d * 2 - 1 + points_2d = points_2d.flatten(end_dim=1) + + features = [] + for fm in feature_maps: + features.append( + torch.nn.functional.grid_sample( + fm.flatten(end_dim=1), points_2d + ) + ) + features = torch.stack(features, dim=1) + features = features.reshape( + bs, num_cams, num_levels, -1, num_anchor, num_pts + ).permute( + 0, 4, 1, 2, 5, 3 + ) # bs, num_anchor, num_cams, num_levels, num_pts, embed_dims + + return features + + def multi_view_level_fusion( + self, + features: torch.Tensor, + weights: torch.Tensor, + ): + bs, num_anchor = weights.shape[:2] + features = weights[..., None] * features.reshape( + features.shape[:-1] + (self.num_groups, self.group_dims) + ) + features = features.sum(dim=2).sum(dim=2) + features = features.reshape( + bs, num_anchor, self.num_pts, self.embed_dims + ) + return features + + +@PLUGIN_LAYERS.register_module() +class DenseDepthNet(BaseModule): + def __init__( + self, + embed_dims=256, + num_depth_layers=1, + equal_focal=100, + max_depth=60, + loss_weight=1.0, + ): + super().__init__() + self.embed_dims = embed_dims + self.equal_focal = equal_focal + self.num_depth_layers = num_depth_layers + self.max_depth = max_depth + self.loss_weight = loss_weight + + self.depth_layers = nn.ModuleList() + for i in 
range(num_depth_layers): + self.depth_layers.append( + nn.Conv2d(embed_dims, 1, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, feature_maps, focal=None, gt_depths=None): + if focal is None: + focal = self.equal_focal + else: + focal = focal.reshape(-1) + depths = [] + for i, feat in enumerate(feature_maps[: self.num_depth_layers]): + depth = self.depth_layers[i](feat.flatten(end_dim=1).float()).exp() + depth = depth.transpose(0, -1) * focal / self.equal_focal + depth = depth.transpose(0, -1) + depths.append(depth) + if gt_depths is not None and self.training: + loss = self.loss(depths, gt_depths) + return loss + return depths + + def loss(self, depth_preds, gt_depths): + loss = 0.0 + for pred, gt in zip(depth_preds, gt_depths): + pred = pred.permute(0, 2, 3, 1).contiguous().reshape(-1) + gt = gt.reshape(-1) + fg_mask = torch.logical_and( + gt > 0.0, torch.logical_not(torch.isnan(pred)) + ) + gt = gt[fg_mask] + pred = pred[fg_mask] + pred = torch.clip(pred, 0.0, self.max_depth) + with autocast(enabled=False): + error = torch.abs(pred - gt).sum() + _loss = ( + error + / max(1.0, len(gt) * len(depth_preds)) + * self.loss_weight + ) + loss = loss + _loss + return loss + + +@FEEDFORWARD_NETWORK.register_module() +class AsymmetricFFN(BaseModule): + def __init__( + self, + in_channels=None, + pre_norm=None, + embed_dims=256, + feedforward_channels=1024, + num_fcs=2, + act_cfg=dict(type="ReLU", inplace=True), + ffn_drop=0.0, + dropout_layer=None, + add_identity=True, + init_cfg=None, + **kwargs, + ): + super(AsymmetricFFN, self).__init__(init_cfg) + assert num_fcs >= 2, ( + "num_fcs should be no less " f"than 2. got {num_fcs}." 
+ ) + self.in_channels = in_channels + self.pre_norm = pre_norm + self.embed_dims = embed_dims + self.feedforward_channels = feedforward_channels + self.num_fcs = num_fcs + self.act_cfg = act_cfg + self.activate = build_activation_layer(act_cfg) + + layers = [] + if in_channels is None: + in_channels = embed_dims + if pre_norm is not None: + self.pre_norm = build_norm_layer(pre_norm, in_channels)[1] + + for _ in range(num_fcs - 1): + layers.append( + Sequential( + Linear(in_channels, feedforward_channels), + self.activate, + nn.Dropout(ffn_drop), + ) + ) + in_channels = feedforward_channels + layers.append(Linear(feedforward_channels, embed_dims)) + layers.append(nn.Dropout(ffn_drop)) + self.layers = Sequential(*layers) + self.dropout_layer = ( + build_dropout(dropout_layer) + if dropout_layer + else torch.nn.Identity() + ) + self.add_identity = add_identity + if self.add_identity: + self.identity_fc = ( + torch.nn.Identity() + if in_channels == embed_dims + else Linear(self.in_channels, embed_dims) + ) + + def forward(self, x, identity=None): + if self.pre_norm is not None: + x = self.pre_norm(x) + out = self.layers(x) + if not self.add_identity: + return self.dropout_layer(out) + if identity is None: + identity = x + identity = self.identity_fc(identity) + return identity + self.dropout_layer(out) diff --git a/projects/mmdet3d_plugin/models/detection3d/__init__.py b/projects/mmdet3d_plugin/models/detection3d/__init__.py new file mode 100644 index 0000000..b2654b6 --- /dev/null +++ b/projects/mmdet3d_plugin/models/detection3d/__init__.py @@ -0,0 +1,9 @@ +from .decoder import SparseBox3DDecoder +from .target import SparseBox3DTarget +from .detection3d_blocks import ( + SparseBox3DRefinementModule, + SparseBox3DKeyPointsGenerator, + SparseBox3DEncoder, +) +from .losses import SparseBox3DLoss +from .detection3d_head import Sparse4DHead diff --git a/projects/mmdet3d_plugin/models/detection3d/decoder.py b/projects/mmdet3d_plugin/models/detection3d/decoder.py new file 
mode 100644 index 0000000..5d000be --- /dev/null +++ b/projects/mmdet3d_plugin/models/detection3d/decoder.py @@ -0,0 +1,115 @@ +from typing import Optional + +import torch + +from mmdet.core.bbox.builder import BBOX_CODERS + +from projects.mmdet3d_plugin.core.box3d import * + +def decode_box(box): + yaw = torch.atan2(box[..., SIN_YAW], box[..., COS_YAW]) + box = torch.cat( + [ + box[..., [X, Y, Z]], + box[..., [W, L, H]].exp(), + yaw[..., None], + box[..., VX:], + ], + dim=-1, + ) + return box + + +@BBOX_CODERS.register_module() +class SparseBox3DDecoder(object): + def __init__( + self, + num_output: int = 300, + score_threshold: Optional[float] = None, + sorted: bool = True, + ): + super(SparseBox3DDecoder, self).__init__() + self.num_output = num_output + self.score_threshold = score_threshold + self.sorted = sorted + + def decode( + self, + cls_scores, + box_preds, + instance_id=None, + quality=None, + visibility=None, + output_idx=-1, + ): + squeeze_cls = instance_id is not None + + cls_scores = cls_scores[output_idx].sigmoid() + + if squeeze_cls: + cls_scores, cls_ids = cls_scores.max(dim=-1) + cls_scores = cls_scores.unsqueeze(dim=-1) + + box_preds = box_preds[output_idx] + bs, num_pred, num_cls = cls_scores.shape + cls_scores, indices = cls_scores.flatten(start_dim=1).topk( + self.num_output, dim=1, sorted=self.sorted + ) + if not squeeze_cls: + cls_ids = indices % num_cls + if self.score_threshold is not None: + mask = cls_scores >= self.score_threshold + + if visibility is not None and visibility[output_idx] is None: + visibility = None + if quality[output_idx] is None: + quality = None + if quality is not None: + centerness = quality[output_idx][..., CNS] + centerness = torch.gather(centerness, 1, indices // num_cls) + cls_scores_origin = cls_scores.clone() + cls_scores *= centerness.sigmoid() + cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True) + if not squeeze_cls: + cls_ids = torch.gather(cls_ids, 1, idx) + if self.score_threshold is not 
None: + mask = torch.gather(mask, 1, idx) + indices = torch.gather(indices, 1, idx) + + output = [] + for i in range(bs): + category_ids = cls_ids[i] + if squeeze_cls: + category_ids = category_ids[indices[i]] + scores = cls_scores[i] + box = box_preds[i, indices[i] // num_cls] + if self.score_threshold is not None: + category_ids = category_ids[mask[i]] + scores = scores[mask[i]] + box = box[mask[i]] + if quality is not None: + scores_origin = cls_scores_origin[i] + if self.score_threshold is not None: + scores_origin = scores_origin[mask[i]] + + box = decode_box(box) + output.append( + { + "boxes_3d": box.cpu(), + "scores_3d": scores.cpu(), + "labels_3d": category_ids.cpu(), + } + ) + if quality is not None: + output[-1]["cls_scores"] = scores_origin.cpu() + if instance_id is not None: + ids = instance_id[i, indices[i]] + if self.score_threshold is not None: + ids = ids[mask[i]] + output[-1]["instance_ids"] = ids + if visibility is not None: + vis_i = visibility[output_idx][i, indices[i] // num_cls, 0].sigmoid() + if self.score_threshold is not None: + vis_i = vis_i[mask[i]] + output[-1]["visibility_scores"] = vis_i.cpu() + return output diff --git a/projects/mmdet3d_plugin/models/detection3d/detection3d_blocks.py b/projects/mmdet3d_plugin/models/detection3d/detection3d_blocks.py new file mode 100644 index 0000000..2bcabe8 --- /dev/null +++ b/projects/mmdet3d_plugin/models/detection3d/detection3d_blocks.py @@ -0,0 +1,316 @@ +import torch +import torch.nn as nn +import numpy as np + +from mmcv.cnn import Linear, Scale, bias_init_with_prob +from mmcv.runner.base_module import Sequential, BaseModule +from mmcv.cnn import xavier_init +from mmcv.cnn.bricks.registry import ( + PLUGIN_LAYERS, + POSITIONAL_ENCODING, +) + +from projects.mmdet3d_plugin.core.box3d import * +from ..blocks import linear_relu_ln + +__all__ = [ + "SparseBox3DRefinementModule", + "SparseBox3DKeyPointsGenerator", + "SparseBox3DEncoder", +] + + +@POSITIONAL_ENCODING.register_module() +class 
# (body of SparseBox3DEncoder — the class statement opens on the previous
# line of the original mangled text)
    # Embeds a decoded 3D box state (position / size / yaw / velocity) into
    # feature space, either summing ("add") or concatenating ("cat") the
    # per-component embeddings.
    def __init__(
        self,
        embed_dims,
        vel_dims=3,
        mode="add",
        output_fc=True,
        in_loops=1,
        out_loops=2,
    ):
        super().__init__()
        assert mode in ["add", "cat"]
        self.embed_dims = embed_dims
        self.vel_dims = vel_dims
        self.mode = mode

        def embedding_layer(input_dims, output_dims):
            # Small MLP built from the shared linear_relu_ln helper.
            return nn.Sequential(
                *linear_relu_ln(output_dims, in_loops, out_loops, input_dims)
            )

        # A scalar embed_dims is expanded to one width per component
        # (pos, size, yaw, vel, output).
        if not isinstance(embed_dims, (list, tuple)):
            embed_dims = [embed_dims] * 5
        self.pos_fc = embedding_layer(3, embed_dims[0])
        self.size_fc = embedding_layer(3, embed_dims[1])
        self.yaw_fc = embedding_layer(2, embed_dims[2])
        if vel_dims > 0:
            self.vel_fc = embedding_layer(self.vel_dims, embed_dims[3])
        if output_fc:
            self.output_fc = embedding_layer(embed_dims[-1], embed_dims[-1])
        else:
            self.output_fc = None

    def forward(self, box_3d: torch.Tensor):
        """Embed box_3d (..., state_dim) into (..., embed_dims) features."""
        pos_feat = self.pos_fc(box_3d[..., [X, Y, Z]])
        size_feat = self.size_fc(box_3d[..., [W, L, H]])
        yaw_feat = self.yaw_fc(box_3d[..., [SIN_YAW, COS_YAW]])
        if self.mode == "add":
            output = pos_feat + size_feat + yaw_feat
        elif self.mode == "cat":
            output = torch.cat([pos_feat, size_feat, yaw_feat], dim=-1)

        if self.vel_dims > 0:
            vel_feat = self.vel_fc(box_3d[..., VX : VX + self.vel_dims])
            if self.mode == "add":
                output = output + vel_feat
            elif self.mode == "cat":
                output = torch.cat([output, vel_feat], dim=-1)
        if self.output_fc is not None:
            output = self.output_fc(output)
        return output


@PLUGIN_LAYERS.register_module()
class SparseBox3DRefinementModule(BaseModule):
    """Refines anchors with per-instance deltas and predicts class /
    quality / visibility logits from the instance features."""

    def __init__(
        self,
        embed_dims=256,
        output_dim=11,
        num_cls=10,
        normalize_yaw=False,
        refine_yaw=False,
        with_cls_branch=True,
        with_quality_estimation=False,
        with_visibility_estimation=False,
    ):
        super(SparseBox3DRefinementModule, self).__init__()
        self.embed_dims = embed_dims
        self.output_dim = output_dim
        self.num_cls = num_cls
        self.normalize_yaw = normalize_yaw
        self.refine_yaw = refine_yaw

        # State components refined additively (yaw only when refine_yaw).
        self.refine_state = [X, Y, Z, W, L, H]
        if self.refine_yaw:
            self.refine_state += [SIN_YAW, COS_YAW]

        self.layers = nn.Sequential(
            *linear_relu_ln(embed_dims, 2, 2),
            Linear(self.embed_dims, self.output_dim),
            Scale([1.0] * self.output_dim),
        )
        self.with_cls_branch = with_cls_branch
        if with_cls_branch:
            self.cls_layers = nn.Sequential(
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(self.embed_dims, self.num_cls),
            )
        self.with_quality_estimation = with_quality_estimation
        if with_quality_estimation:
            # Two quality outputs — presumably centerness + yawness;
            # TODO(review): confirm against the quality loss definition.
            self.quality_layers = nn.Sequential(
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(self.embed_dims, 2),
            )
        self.with_visibility_estimation = with_visibility_estimation
        if with_visibility_estimation:
            self.visibility_layers = nn.Sequential(
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(self.embed_dims, 1),
            )

    def init_weight(self):
        if self.with_cls_branch:
            # Focal-loss-style prior on the classification bias.
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.cls_layers[-1].bias, bias_init)
        if self.with_visibility_estimation:
            # Neutral prior — model starts with no bias toward visible/occluded
            nn.init.constant_(self.visibility_layers[-1].bias, 0.0)

    def forward(
        self,
        instance_feature: torch.Tensor,
        anchor: torch.Tensor,
        anchor_embed: torch.Tensor,
        time_interval: torch.Tensor = 1.0,
        return_cls=True,
    ):
        """Return (refined_anchor, cls, quality, visibility).

        time_interval may be a float or a tensor; a float is promoted to a
        tensor before the velocity computation below.
        """
        feature = instance_feature + anchor_embed
        output = self.layers(feature)
        # Additive refinement of the selected state components.
        output[..., self.refine_state] = (
            output[..., self.refine_state] + anchor[..., self.refine_state]
        )
        if self.normalize_yaw:
            output[..., [SIN_YAW, COS_YAW]] = torch.nn.functional.normalize(
                output[..., [SIN_YAW, COS_YAW]], dim=-1
            )
        if self.output_dim > 8:
            if not isinstance(time_interval, torch.Tensor):
                time_interval = instance_feature.new_tensor(time_interval)
            # Transpose trick broadcasts the per-sample time interval over
            # the trailing dims: velocity = predicted translation / dt.
            translation = torch.transpose(output[..., VX:], 0, -1)
            velocity = torch.transpose(translation / time_interval, 0, -1)
            output[..., VX:] = velocity + anchor[..., VX:]

        if return_cls:
            assert self.with_cls_branch, "Without classification layers !!!"
            cls = self.cls_layers(instance_feature)
        else:
            cls = None
        if return_cls and self.with_quality_estimation:
            quality = self.quality_layers(feature)
        else:
            quality = None
        if return_cls and self.with_visibility_estimation:
            # (bs, N, 1) — raw logit; 1 = sensor-visible, 0 = occluded
            visibility = self.visibility_layers(instance_feature)
        else:
            visibility = None
        return output, cls, quality, visibility


@PLUGIN_LAYERS.register_module()
class SparseBox3DKeyPointsGenerator(BaseModule):
    """Generates fixed + learnable key points inside each anchor box and
    projects anchors / key points across frames."""

    def __init__(
        self,
        embed_dims=256,
        num_learnable_pts=0,
        fix_scale=None,
    ):
        super(SparseBox3DKeyPointsGenerator, self).__init__()
        self.embed_dims = embed_dims
        self.num_learnable_pts = num_learnable_pts
        if fix_scale is None:
            fix_scale = ((0.0, 0.0, 0.0),)
        # Frozen parameter (not a buffer) so it still appears in state_dict.
        self.fix_scale = nn.Parameter(
            torch.tensor(fix_scale), requires_grad=False
        )
        self.num_pts = len(self.fix_scale) + num_learnable_pts
        if num_learnable_pts > 0:
            self.learnable_fc = Linear(self.embed_dims, num_learnable_pts * 3)

    def init_weight(self):
        if self.num_learnable_pts > 0:
            xavier_init(self.learnable_fc, distribution="uniform", bias=0.0)

    def forward(
        self,
        anchor,
        instance_feature=None,
        T_cur2temp_list=None,
        cur_timestamp=None,
        temp_timestamps=None,
    ):
        """Return key points in the current frame, and additionally their
        back-projections into each temporal frame when timestamps and
        cur->temp transforms are provided.
        """
        bs, num_anchor = anchor.shape[:2]
        # Box sizes are stored in log-space; exp() recovers metric size.
        size = anchor[..., None, [W, L, H]].exp()
        key_points = self.fix_scale * size
        if self.num_learnable_pts > 0 and instance_feature is not None:
            # sigmoid() - 0.5 keeps learnable offsets inside the box.
            learnable_scale = (
                self.learnable_fc(instance_feature)
                .reshape(bs, num_anchor, self.num_learnable_pts, 3)
                .sigmoid()
                - 0.5
            )
            key_points = torch.cat(
                [key_points, learnable_scale * size], dim=-2
            )

        # Yaw rotation about the z axis from the (sin, cos) encoding.
        rotation_mat = anchor.new_zeros([bs, num_anchor, 3, 3])

        rotation_mat[:, :, 0, 0] = anchor[:, :, COS_YAW]
        rotation_mat[:, :, 0, 1] = -anchor[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 0] = anchor[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 1] = anchor[:, :, COS_YAW]
        rotation_mat[:, :, 2, 2] = 1

        key_points = torch.matmul(
            rotation_mat[:, :, None], key_points[..., None]
        ).squeeze(-1)
        key_points = key_points + anchor[..., None, [X, Y, Z]]

        # NOTE: returns a bare tensor here, but a (tensor, list) pair below
        # when temporal information is available — callers must handle both.
        if (
            cur_timestamp is None
            or temp_timestamps is None
            or T_cur2temp_list is None
            or len(temp_timestamps) == 0
        ):
            return key_points

        temp_key_points_list = []
        velocity = anchor[..., VX:]
        for i, t_time in enumerate(temp_timestamps):
            time_interval = cur_timestamp - t_time
            # Constant-velocity motion model back to the temporal frame.
            translation = (
                velocity
                * time_interval.to(dtype=velocity.dtype)[:, None, None]
            )
            temp_key_points = key_points - translation[:, :, None]
            T_cur2temp = T_cur2temp_list[i].to(dtype=key_points.dtype)
            # Homogeneous transform into the temporal frame.
            temp_key_points = (
                T_cur2temp[:, None, None, :3]
                @ torch.cat(
                    [
                        temp_key_points,
                        torch.ones_like(temp_key_points[..., :1]),
                    ],
                    dim=-1,
                ).unsqueeze(-1)
            )
            temp_key_points = temp_key_points.squeeze(-1)
            temp_key_points_list.append(temp_key_points)
        return key_points, temp_key_points_list

    @staticmethod
    def anchor_projection(
        anchor,
        T_src2dst_list,
        src_timestamp=None,
        dst_timestamps=None,
        time_intervals=None,
    ):
        """Project anchors into each destination frame given src->dst
        transforms, compensating for ego + object motion when a time
        interval is available. Returns one projected anchor per transform.
        """
        dst_anchors = []
        for i in range(len(T_src2dst_list)):
            vel = anchor[..., VX:]
            vel_dim = vel.shape[-1]
            T_src2dst = torch.unsqueeze(
                T_src2dst_list[i].to(dtype=anchor.dtype), dim=1
            )

            center = anchor[..., [X, Y, Z]]
            if time_intervals is not None:
                time_interval = time_intervals[i]
            elif src_timestamp is not None and dst_timestamps is not None:
                time_interval = (src_timestamp - dst_timestamps[i]).to(
                    dtype=vel.dtype
                )
            else:
                time_interval = None
            if time_interval is not None:
                # Constant-velocity compensation of object motion.
                translation = vel.transpose(0, -1) * time_interval
                translation = translation.transpose(0, -1)
                center = center - translation

            center = (
                torch.matmul(
                    T_src2dst[..., :3, :3], center[..., None]
                ).squeeze(dim=-1)
                + T_src2dst[..., :3, 3]
            )
            size = anchor[..., [W, L, H]]
            # Rotate the yaw encoding: the 2x2 rotation applied to
            # (cos, sin) yields (sin', cos') after the [1, 0] reorder.
            yaw = torch.matmul(
                T_src2dst[..., :2, :2],
                anchor[..., [COS_YAW, SIN_YAW], None],
            ).squeeze(-1)
            yaw = yaw[..., [1, 0]]
            vel = torch.matmul(
                T_src2dst[..., :vel_dim, :vel_dim], vel[..., None]
            ).squeeze(-1)
            dst_anchor = torch.cat([center, size, yaw, vel], dim=-1)
            dst_anchors.append(dst_anchor)
        return dst_anchors

    @staticmethod
    def distance(anchor):
        # BEV (x, y) distance of each anchor from the ego origin.
        return torch.norm(anchor[..., :2], p=2, dim=-1)


# === new file: projects/mmdet3d_plugin/models/detection3d/detection3d_head.py ===
from typing import List, Optional, Tuple, Union
import warnings

import numpy as np
import torch
import torch.nn as nn

from mmcv.cnn.bricks.registry import (
    ATTENTION,
    PLUGIN_LAYERS,
    POSITIONAL_ENCODING,
    FEEDFORWARD_NETWORK,
    NORM_LAYERS,
)
from mmcv.runner import BaseModule, force_fp32
from mmcv.utils import build_from_cfg
from mmdet.core.bbox.builder import BBOX_SAMPLERS
from mmdet.core.bbox.builder import BBOX_CODERS
from mmdet.models import HEADS, LOSSES
from mmdet.core import reduce_mean

from ..blocks import DeformableFeatureAggregation as DFG

__all__ = ["Sparse4DHead"]


# NOTE(review): the Sparse4DHead definition continues beyond this block;
# only the opening of its __init__ signature is visible here.
@HEADS.register_module()
class Sparse4DHead(BaseModule):
    def __init__(
        self,
        instance_bank: dict,
        anchor_encoder: dict,
        graph_model: dict,
        norm_layer: dict,
        ffn: dict,
        deformable_model: dict,
        refine_layer: dict,
        num_decoder: int = 6,
        num_single_frame_decoder: int = -1,
        temp_graph_model: dict = None,
        loss_cls: dict = None,
        loss_reg: dict = None,
        loss_visibility: dict = None,
        decoder: dict = None,
        sampler: dict = None,
        gt_cls_key: str = "gt_labels_3d",
        gt_reg_key: str = "gt_bboxes_3d",
        gt_id_key: str = "instance_id",
        gt_visibility_key: str = "gt_visibility",
        with_instance_id: bool = True,
        task_prefix: str = 'det',
        reg_weights: List = 
None, + operation_order: Optional[List[str]] = None, + cls_threshold_to_reg: float = -1, + dn_loss_weight: float = 5.0, + decouple_attn: bool = True, + temporal_warmup_order: Optional[List[str]] = None, + warmup_refine_layer: dict = None, + warmup_ffn: dict = None, + warmup_supervise_all: bool = False, + init_cfg: dict = None, + **kwargs, + ): + super(Sparse4DHead, self).__init__(init_cfg) + self.num_decoder = num_decoder + self.num_single_frame_decoder = num_single_frame_decoder + self.gt_cls_key = gt_cls_key + self.gt_reg_key = gt_reg_key + self.gt_id_key = gt_id_key + self.with_instance_id = with_instance_id + self.task_prefix = task_prefix + self.cls_threshold_to_reg = cls_threshold_to_reg + self.dn_loss_weight = dn_loss_weight + self.decouple_attn = decouple_attn + + if reg_weights is None: + self.reg_weights = [1.0] * 10 + else: + self.reg_weights = reg_weights + + if operation_order is None: + operation_order = [ + "temp_gnn", + "gnn", + "norm", + "deformable", + "norm", + "ffn", + "norm", + "refine", + ] * num_decoder + # delete the 'gnn' and 'norm' layers in the first transformer blocks + operation_order = operation_order[3:] + self.operation_order = operation_order + + # =========== build modules =========== + def build(cfg, registry): + if cfg is None: + return None + return build_from_cfg(cfg, registry) + + self.gt_visibility_key = gt_visibility_key + self.instance_bank = build(instance_bank, PLUGIN_LAYERS) + self.anchor_encoder = build(anchor_encoder, POSITIONAL_ENCODING) + self.sampler = build(sampler, BBOX_SAMPLERS) + self.decoder = build(decoder, BBOX_CODERS) + self.loss_cls = build(loss_cls, LOSSES) + self.loss_reg = build(loss_reg, LOSSES) + self.loss_visibility = build(loss_visibility, LOSSES) if loss_visibility else None + self.op_config_map = { + "temp_gnn": [temp_graph_model, ATTENTION], + "gnn": [graph_model, ATTENTION], + "norm": [norm_layer, NORM_LAYERS], + "ffn": [ffn, FEEDFORWARD_NETWORK], + "deformable": [deformable_model, ATTENTION], + 
"refine": [refine_layer, PLUGIN_LAYERS], + } + self.layers = nn.ModuleList( + [ + build(*self.op_config_map.get(op, [None, None])) + for op in self.operation_order + ] + ) + self.temporal_warmup_order = list(temporal_warmup_order) if temporal_warmup_order else [] + # For "refine" in warmup, use warmup_refine_layer if provided (should have + # with_cls_branch=False, with_quality_estimation=False to avoid unused params). + # Falls back to refine_layer if warmup_refine_layer is not specified. + warmup_op_config_map = dict(self.op_config_map) + if warmup_refine_layer is not None: + warmup_op_config_map["refine"] = [warmup_refine_layer, PLUGIN_LAYERS] + if warmup_ffn is not None: + warmup_op_config_map["ffn"] = [warmup_ffn, FEEDFORWARD_NETWORK] + self.warmup_layers = nn.ModuleList( + [ + build(*warmup_op_config_map.get(op, [None, None])) + for op in self.temporal_warmup_order + ] + ) + self.embed_dims = self.instance_bank.embed_dims + if self.decouple_attn: + self.fc_before = nn.Linear( + self.embed_dims, self.embed_dims * 2, bias=False + ) + self.fc_after = nn.Linear( + self.embed_dims * 2, self.embed_dims, bias=False + ) + else: + self.fc_before = nn.Identity() + self.fc_after = nn.Identity() + # Dedicated fc projections for warmup GNN — must NOT share with fc_before/fc_after + # because gradient checkpointing would fire DDP hooks twice for shared params. 
+ has_warmup_gnn = any(op == "gnn" for op in self.temporal_warmup_order) + if has_warmup_gnn and self.decouple_attn: + self.warmup_fc_before = nn.Linear( + self.embed_dims, self.embed_dims * 2, bias=False + ) + self.warmup_fc_after = nn.Linear( + self.embed_dims * 2, self.embed_dims, bias=False + ) + else: + self.warmup_fc_before = nn.Identity() + self.warmup_fc_after = nn.Identity() + self.warmup_supervise_all = warmup_supervise_all + + def init_weights(self): + for i, op in enumerate(self.operation_order): + if self.layers[i] is None: + continue + elif op != "refine": + for p in self.layers[i].parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for i, op in enumerate(self.temporal_warmup_order): + if self.warmup_layers[i] is None: + continue + elif op != "refine": + for p in self.warmup_layers[i].parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + if isinstance(self.warmup_fc_before, nn.Linear): + nn.init.xavier_uniform_(self.warmup_fc_before.weight) + if isinstance(self.warmup_fc_after, nn.Linear): + nn.init.xavier_uniform_(self.warmup_fc_after.weight) + for m in self.modules(): + if hasattr(m, "init_weight"): + m.init_weight() + + def _gnn_with_layer(self, layer, feat, anchor_embed): + """Self-attention (GNN) using an explicit layer rather than self.layers[i]. + Used by the temporal warmup block where queries attend only to each other. 
+ Uses warmup_fc_before/after (not shared with main decoder fc projections).""" + if self.decouple_attn: + q = torch.cat([feat, anchor_embed], dim=-1) + v = self.warmup_fc_before(feat) + return self.warmup_fc_after(layer(q, q, v)) + else: + v = self.warmup_fc_before(feat) + return self.warmup_fc_after(layer(feat, feat, v, query_pos=anchor_embed)) + + def graph_model( + self, + index, + query, + key=None, + value=None, + query_pos=None, + key_pos=None, + **kwargs, + ): + if self.decouple_attn: + query = torch.cat([query, query_pos], dim=-1) + if key is not None: + key = torch.cat([key, key_pos], dim=-1) + query_pos, key_pos = None, None + if value is not None: + value = self.fc_before(value) + return self.fc_after( + self.layers[index]( + query, + key, + value, + query_pos=query_pos, + key_pos=key_pos, + **kwargs, + ) + ) + + def forward( + self, + feature_maps: Union[torch.Tensor, List], + metas: dict, + ): + if isinstance(feature_maps, torch.Tensor): + feature_maps = [feature_maps] + batch_size = feature_maps[0].shape[0] + + # ========= get instance info ============ + if ( + self.sampler.dn_metas is not None + and self.sampler.dn_metas["dn_anchor"].shape[0] != batch_size + ): + self.sampler.dn_metas = None + ( + instance_feature, + anchor, + temp_instance_feature, + temp_anchor, + time_interval, + ) = self.instance_bank.get( + batch_size, metas, dn_metas=self.sampler.dn_metas + ) + + # ========= prepare for denosing training ============ + # 1. get dn metas: noisy-anchors and corresponding GT + # 2. concat learnable instances and noisy instances + # 3. 
get attention mask + attn_mask = None + dn_metas = None + temp_dn_reg_target = None + if self.training and hasattr(self.sampler, "get_dn_anchors"): + if self.gt_id_key in metas["img_metas"][0]: + gt_instance_id = [ + torch.from_numpy(x[self.gt_id_key]).cuda() + for x in metas["img_metas"] + ] + else: + gt_instance_id = None + dn_metas = self.sampler.get_dn_anchors( + metas[self.gt_cls_key], + metas[self.gt_reg_key], + gt_instance_id, + ) + if dn_metas is not None: + ( + dn_anchor, + dn_reg_target, + dn_cls_target, + dn_attn_mask, + valid_mask, + dn_id_target, + ) = dn_metas + num_dn_anchor = dn_anchor.shape[1] + if dn_anchor.shape[-1] != anchor.shape[-1]: + remain_state_dims = anchor.shape[-1] - dn_anchor.shape[-1] + dn_anchor = torch.cat( + [ + dn_anchor, + dn_anchor.new_zeros( + batch_size, num_dn_anchor, remain_state_dims + ), + ], + dim=-1, + ) + anchor = torch.cat([anchor, dn_anchor], dim=1) + instance_feature = torch.cat( + [ + instance_feature, + instance_feature.new_zeros( + batch_size, num_dn_anchor, instance_feature.shape[-1] + ), + ], + dim=1, + ) + num_instance = instance_feature.shape[1] + num_free_instance = num_instance - num_dn_anchor + attn_mask = anchor.new_ones( + (num_instance, num_instance), dtype=torch.bool + ) + attn_mask[:num_free_instance, :num_free_instance] = False + attn_mask[num_free_instance:, num_free_instance:] = dn_attn_mask + + anchor_embed = self.anchor_encoder(anchor) + if temp_anchor is not None: + temp_anchor_embed = self.anchor_encoder(temp_anchor) + else: + temp_anchor_embed = None + + # =========== temporal warmup (Block 0) ==================== + # Social self-attention among temporal queries before they are merged + # with current-frame detections. No image features used here. + # Always runs (even on first frame) so warmup params always receive + # gradients — avoids the need for find_unused_parameters=True. 
+ if self.temporal_warmup_order: + if temp_instance_feature is not None: + # Temporal case: warm up the cached temporal features + w_feat = temp_instance_feature + w_anchor = temp_anchor + w_anchor_embed = temp_anchor_embed + is_temporal = True + else: + # First frame: warm up the first num_temp_instances current slots + num_ti = self.instance_bank.num_temp_instances + w_feat = instance_feature[:, :num_ti] + w_anchor = anchor[:, :num_ti] + w_anchor_embed = anchor_embed[:, :num_ti] + is_temporal = False + w_cls, w_qt, w_vis = None, None, None + for i, op in enumerate(self.temporal_warmup_order): + if self.warmup_layers[i] is None: + continue + if op == "gnn": + w_feat = self._gnn_with_layer( + self.warmup_layers[i], w_feat, w_anchor_embed + ) + elif op in ("norm", "ffn"): + w_feat = self.warmup_layers[i](w_feat) + elif op == "refine": + w_anchor, w_cls, w_qt, w_vis = self.warmup_layers[i]( + w_feat, + w_anchor, + w_anchor_embed, + time_interval=time_interval, + return_cls=True, + ) + w_anchor_embed = self.anchor_encoder(w_anchor) + if is_temporal: + temp_instance_feature = w_feat + temp_anchor = w_anchor + temp_anchor_embed = w_anchor_embed + else: + # Inject warmed first-frame features back so warmup params + # connect to the loss via the main decoder. + num_ti = self.instance_bank.num_temp_instances + instance_feature = torch.cat( + [w_feat, instance_feature[:, num_ti:]], dim=1 + ) + anchor = torch.cat([w_anchor, anchor[:, num_ti:]], dim=1) + anchor_embed = torch.cat( + [w_anchor_embed, anchor_embed[:, num_ti:]], dim=1 + ) + + # =================== forward the layers ==================== + prediction = [] + classification = [] + quality = [] + visibility = [] + num_warmup_preds = 0 + # If warmup produced a refine prediction on temporal instances, prepend it + # so it gets supervised like any other intermediate decoder stage. 
+ # Pads non-temporal slots (num_ti:num_anchor) with initial anchor positions + # and near-zero cls logits so the sampler treats them as background. + # NOTE: do NOT gate this on is_temporal — the cls branch must always + # participate in the loss so DDP doesn't see unused parameters on + # first-frame batches (which have is_temporal=False). + if ( + self.temporal_warmup_order + and w_cls is not None + and dn_metas is None + ): + num_ti = self.instance_bank.num_temp_instances + num_anchor = self.instance_bank.num_anchor + warmup_pred = torch.cat( + [w_anchor, anchor[:, num_ti:num_anchor]], dim=1 + ) + warmup_cls = torch.cat( + [ + w_cls, + w_cls.new_full( + [batch_size, num_anchor - num_ti, w_cls.shape[-1]], -10.0 + ), + ], + dim=1, + ) + warmup_qt = ( + torch.cat( + [ + w_qt, + w_qt.new_zeros(batch_size, num_anchor - num_ti, w_qt.shape[-1]), + ], + dim=1, + ) + if w_qt is not None + else None + ) + warmup_vis = ( + torch.cat( + [ + w_vis, + w_vis.new_zeros(batch_size, num_anchor - num_ti, w_vis.shape[-1]), + ], + dim=1, + ) + if w_vis is not None + else None + ) + prediction.append(warmup_pred) + classification.append(warmup_cls) + quality.append(warmup_qt) + num_warmup_preds = 1 + visibility.append(warmup_vis) + num_main_decoder_refines = 0 + for i, op in enumerate(self.operation_order): + if self.layers[i] is None: + continue + elif op == "temp_gnn": + instance_feature = self.graph_model( + i, + instance_feature, + temp_instance_feature, + temp_instance_feature, + query_pos=anchor_embed, + key_pos=temp_anchor_embed, + attn_mask=attn_mask + if temp_instance_feature is None + else None, + ) + elif op == "gnn": + instance_feature = self.graph_model( + i, + instance_feature, + value=instance_feature, + query_pos=anchor_embed, + attn_mask=attn_mask, + ) + elif op == "norm" or op == "ffn": + instance_feature = self.layers[i](instance_feature) + elif op == "deformable": + instance_feature = self.layers[i]( + instance_feature, + anchor, + anchor_embed, + feature_maps, + 
metas, + ) + elif op == "refine": + anchor, cls, qt, vis = self.layers[i]( + instance_feature, + anchor, + anchor_embed, + time_interval=time_interval, + return_cls=True, + ) + prediction.append(anchor) + classification.append(cls) + quality.append(qt) + visibility.append(vis) + num_main_decoder_refines += 1 + if num_main_decoder_refines == self.num_single_frame_decoder: + instance_feature, anchor = self.instance_bank.update( + instance_feature, anchor, cls, + cached_feature_override=temp_instance_feature, + cached_anchor_override=temp_anchor, + ) + if ( + dn_metas is not None + and self.sampler.num_temp_dn_groups > 0 + and dn_id_target is not None + ): + ( + instance_feature, + anchor, + temp_dn_reg_target, + temp_dn_cls_target, + temp_valid_mask, + dn_id_target, + ) = self.sampler.update_dn( + instance_feature, + anchor, + dn_reg_target, + dn_cls_target, + valid_mask, + dn_id_target, + self.instance_bank.num_anchor, + self.instance_bank.mask, + ) + anchor_embed = self.anchor_encoder(anchor) + if ( + len(prediction) > self.num_single_frame_decoder + and temp_anchor_embed is not None + ): + temp_anchor_embed = anchor_embed[ + :, : self.instance_bank.num_temp_instances + ] + else: + raise NotImplementedError(f"{op} is not supported.") + + output = {} + + # split predictions of learnable instances and noisy instances + if dn_metas is not None: + dn_classification = [ + x[:, num_free_instance:] for x in classification + ] + classification = [x[:, :num_free_instance] for x in classification] + dn_prediction = [x[:, num_free_instance:] for x in prediction] + prediction = [x[:, :num_free_instance] for x in prediction] + quality = [ + x[:, :num_free_instance] if x is not None else None + for x in quality + ] + visibility = [ + x[:, :num_free_instance] if x is not None else None + for x in visibility + ] + output.update( + { + "dn_prediction": dn_prediction, + "dn_classification": dn_classification, + "dn_reg_target": dn_reg_target, + "dn_cls_target": dn_cls_target, + 
"dn_valid_mask": valid_mask, + } + ) + if temp_dn_reg_target is not None: + output.update( + { + "temp_dn_reg_target": temp_dn_reg_target, + "temp_dn_cls_target": temp_dn_cls_target, + "temp_dn_valid_mask": temp_valid_mask, + "dn_id_target": dn_id_target, + } + ) + dn_cls_target = temp_dn_cls_target + valid_mask = temp_valid_mask + dn_instance_feature = instance_feature[:, num_free_instance:] + dn_anchor = anchor[:, num_free_instance:] + instance_feature = instance_feature[:, :num_free_instance] + anchor_embed = anchor_embed[:, :num_free_instance] + anchor = anchor[:, :num_free_instance] + cls = cls[:, :num_free_instance] + + # cache dn_metas for temporal denoising + self.sampler.cache_dn( + dn_instance_feature, + dn_anchor, + dn_cls_target, + valid_mask, + dn_id_target, + ) + output.update( + { + "classification": classification, + "prediction": prediction, + "quality": quality, + "visibility": visibility, + "instance_feature": instance_feature, + "anchor_embed": anchor_embed, + "num_warmup_preds": num_warmup_preds, + } + ) + + # cache current instances for temporal modeling + self.instance_bank.cache( + instance_feature, anchor, cls, metas, feature_maps + ) + if self.with_instance_id: + instance_id = self.instance_bank.get_instance_id( + cls, anchor, self.decoder.score_threshold + ) + output["instance_id"] = instance_id + return output + + @force_fp32(apply_to=("model_outs")) + def loss(self, model_outs, data, feature_maps=None): + # ===================== prediction losses ====================== + cls_scores = model_outs["classification"] + reg_preds = model_outs["prediction"] + quality = model_outs["quality"] + vis_scores = model_outs.get("visibility", [None] * len(cls_scores)) + num_warmup_preds = model_outs.get("num_warmup_preds", 0) + output = {} + for decoder_idx, (cls, reg, qt, vis) in enumerate( + zip(cls_scores, reg_preds, quality, vis_scores) + ): + reg = reg[..., : len(self.reg_weights)] + if ( + self.warmup_supervise_all + and num_warmup_preds <= 
decoder_idx < num_warmup_preds + self.num_single_frame_decoder + and self.gt_visibility_key in data + ): + gt_vis = data[self.gt_visibility_key] + gt_cls = [l[v > 0] for l, v in zip(data[self.gt_cls_key], gt_vis)] + gt_reg = [b[v > 0] for b, v in zip(data[self.gt_reg_key], gt_vis)] + else: + gt_cls = data[self.gt_cls_key] + gt_reg = data[self.gt_reg_key] + cls_target, reg_target, reg_weights = self.sampler.sample( + cls, + reg, + gt_cls, + gt_reg, + ) + reg_target = reg_target[..., : len(self.reg_weights)] + reg_target_full = reg_target.clone() + mask = torch.logical_not(torch.all(reg_target == 0, dim=-1)) + mask_valid = mask.clone() + + num_pos = max( + reduce_mean(torch.sum(mask).to(dtype=reg.dtype)), 1.0 + ) + if self.cls_threshold_to_reg > 0: + threshold = self.cls_threshold_to_reg + mask = torch.logical_and( + mask, cls.max(dim=-1).values.sigmoid() > threshold + ) + + cls = cls.flatten(end_dim=1) + cls_target = cls_target.flatten(end_dim=1) + cls_loss = self.loss_cls(cls, cls_target, avg_factor=num_pos) + + mask = mask.reshape(-1) + reg_weights = reg_weights * reg.new_tensor(self.reg_weights) + reg_target = reg_target.flatten(end_dim=1)[mask] + reg = reg.flatten(end_dim=1)[mask] + reg_weights = reg_weights.flatten(end_dim=1)[mask] + reg_target = torch.where( + reg_target.isnan(), reg.new_tensor(0.0), reg_target + ) + cls_target = cls_target[mask] + if qt is not None: + qt = qt.flatten(end_dim=1)[mask] + + reg_loss = self.loss_reg( + reg, + reg_target, + weight=reg_weights, + avg_factor=num_pos, + prefix=f"{self.task_prefix}_", + suffix=f"_{decoder_idx}", + quality=qt, + cls_target=cls_target, + ) + + output[f"{self.task_prefix}_loss_cls_{decoder_idx}"] = cls_loss + output.update(reg_loss) + + # ---- visibility loss (only on matched / positive anchors) ---- + if ( + vis is not None + and self.loss_visibility is not None + and self.gt_visibility_key in data + ): + gt_vis_list = data[self.gt_visibility_key] + bs_v, num_pred_v = vis.shape[:2] + vis_target = 
vis.new_zeros(bs_v, num_pred_v) + for b_i, (pred_idx, target_idx) in enumerate( + self.sampler.indices + ): + if ( + pred_idx is not None + and len(pred_idx) > 0 + and len(gt_vis_list[b_i]) > 0 + ): + vis_target[b_i, pred_idx] = ( + gt_vis_list[b_i] + .to(vis.device) + .float()[target_idx] + ) + matched = mask_valid.reshape(-1) + vis_loss = self.loss_visibility( + vis.squeeze(-1).flatten(end_dim=1)[matched].unsqueeze(-1), + vis_target.flatten(end_dim=1)[matched].long(), + avg_factor=num_pos, + ) + output[ + f"{self.task_prefix}_loss_visibility_{decoder_idx}" + ] = vis_loss + + if "dn_prediction" not in model_outs: + return output + + # ===================== denoising losses ====================== + dn_cls_scores = model_outs["dn_classification"] + dn_reg_preds = model_outs["dn_prediction"] + + ( + dn_valid_mask, + dn_cls_target, + dn_reg_target, + dn_pos_mask, + reg_weights, + num_dn_pos, + ) = self.prepare_for_dn_loss(model_outs) + for decoder_idx, (cls, reg) in enumerate( + zip(dn_cls_scores, dn_reg_preds) + ): + if ( + "temp_dn_valid_mask" in model_outs + and decoder_idx == self.num_single_frame_decoder + ): + ( + dn_valid_mask, + dn_cls_target, + dn_reg_target, + dn_pos_mask, + reg_weights, + num_dn_pos, + ) = self.prepare_for_dn_loss(model_outs, prefix="temp_") + + cls_loss = self.loss_cls( + cls.flatten(end_dim=1)[dn_valid_mask], + dn_cls_target, + avg_factor=num_dn_pos, + ) + reg_loss = self.loss_reg( + reg.flatten(end_dim=1)[dn_valid_mask][dn_pos_mask][ + ..., : len(self.reg_weights) + ], + dn_reg_target, + avg_factor=num_dn_pos, + weight=reg_weights, + prefix=f"{self.task_prefix}_", + suffix=f"_dn_{decoder_idx}", + ) + output[f"{self.task_prefix}_loss_cls_dn_{decoder_idx}"] = cls_loss + output.update(reg_loss) + return output + + def prepare_for_dn_loss(self, model_outs, prefix=""): + dn_valid_mask = model_outs[f"{prefix}dn_valid_mask"].flatten(end_dim=1) + dn_cls_target = model_outs[f"{prefix}dn_cls_target"].flatten( + end_dim=1 + )[dn_valid_mask] + 
dn_reg_target = model_outs[f"{prefix}dn_reg_target"].flatten( + end_dim=1 + )[dn_valid_mask][..., : len(self.reg_weights)] + dn_pos_mask = dn_cls_target >= 0 + dn_reg_target = dn_reg_target[dn_pos_mask] + reg_weights = dn_reg_target.new_tensor(self.reg_weights)[None].tile( + dn_reg_target.shape[0], 1 + ) + num_dn_pos = max( + reduce_mean(torch.sum(dn_valid_mask).to(dtype=reg_weights.dtype)), + 1.0, + ) + return ( + dn_valid_mask, + dn_cls_target, + dn_reg_target, + dn_pos_mask, + reg_weights, + num_dn_pos, + ) + + @force_fp32(apply_to=("model_outs")) + def post_process(self, model_outs, output_idx=-1): + vis_list = model_outs.get("visibility") + # Only pass visibility to decode() when the head is active (not all-None). + # This keeps backward-compatibility with decoders that lack the param. + vis_kwarg = {} + if vis_list is not None and any(v is not None for v in vis_list): + vis_kwarg = {"visibility": vis_list} + return self.decoder.decode( + model_outs["classification"], + model_outs["prediction"], + instance_id=model_outs.get("instance_id"), + quality=model_outs.get("quality"), + output_idx=output_idx, + **vis_kwarg, + ) diff --git a/projects/mmdet3d_plugin/models/detection3d/losses.py b/projects/mmdet3d_plugin/models/detection3d/losses.py new file mode 100644 index 0000000..f4d6656 --- /dev/null +++ b/projects/mmdet3d_plugin/models/detection3d/losses.py @@ -0,0 +1,93 @@ +import torch +import torch.nn as nn + +from mmcv.utils import build_from_cfg +from mmdet.models.builder import LOSSES + +from projects.mmdet3d_plugin.core.box3d import * + + +@LOSSES.register_module() +class SparseBox3DLoss(nn.Module): + def __init__( + self, + loss_box, + loss_centerness=None, + loss_yawness=None, + cls_allow_reverse=None, + ): + super().__init__() + + def build(cfg, registry): + if cfg is None: + return None + return build_from_cfg(cfg, registry) + + self.loss_box = build(loss_box, LOSSES) + self.loss_cns = build(loss_centerness, LOSSES) + self.loss_yns = build(loss_yawness, 
LOSSES) + self.cls_allow_reverse = cls_allow_reverse + + def forward( + self, + box, + box_target, + weight=None, + avg_factor=None, + prefix="", + suffix="", + quality=None, + cls_target=None, + **kwargs, + ): + # Some categories do not distinguish between positive and negative + # directions. For example, barrier in nuScenes dataset. + if self.cls_allow_reverse is not None and cls_target is not None: + if_reverse = ( + torch.nn.functional.cosine_similarity( + box_target[..., [SIN_YAW, COS_YAW]], + box[..., [SIN_YAW, COS_YAW]], + dim=-1, + ) + < 0 + ) + if_reverse = ( + torch.isin( + cls_target, cls_target.new_tensor(self.cls_allow_reverse) + ) + & if_reverse + ) + box_target[..., [SIN_YAW, COS_YAW]] = torch.where( + if_reverse[..., None], + -box_target[..., [SIN_YAW, COS_YAW]], + box_target[..., [SIN_YAW, COS_YAW]], + ) + + output = {} + box_loss = self.loss_box( + box, box_target, weight=weight, avg_factor=avg_factor + ) + output[f"{prefix}loss_box{suffix}"] = box_loss + + if quality is not None: + cns = quality[..., CNS] + yns = quality[..., YNS].sigmoid() + cns_target = torch.norm( + box_target[..., [X, Y, Z]] - box[..., [X, Y, Z]], p=2, dim=-1 + ) + cns_target = torch.exp(-cns_target) + cns_loss = self.loss_cns(cns, cns_target, avg_factor=avg_factor) + output[f"{prefix}loss_cns{suffix}"] = cns_loss + + yns_target = ( + torch.nn.functional.cosine_similarity( + box_target[..., [SIN_YAW, COS_YAW]], + box[..., [SIN_YAW, COS_YAW]], + dim=-1, + ) + > 0 + ) + yns_target = yns_target.float() + yns_loss = self.loss_yns(yns, yns_target, avg_factor=avg_factor) + output[f"{prefix}loss_yns{suffix}"] = yns_loss + return output diff --git a/projects/mmdet3d_plugin/models/detection3d/target.py b/projects/mmdet3d_plugin/models/detection3d/target.py new file mode 100644 index 0000000..fa284e7 --- /dev/null +++ b/projects/mmdet3d_plugin/models/detection3d/target.py @@ -0,0 +1,437 @@ +import torch +import numpy as np +import torch.nn.functional as F +from scipy.optimize import 
linear_sum_assignment + +from mmdet.core.bbox.builder import BBOX_SAMPLERS + +from projects.mmdet3d_plugin.core.box3d import * +from ..base_target import BaseTargetWithDenoising + + +__all__ = ["SparseBox3DTarget"] + + +@BBOX_SAMPLERS.register_module() +class SparseBox3DTarget(BaseTargetWithDenoising): + def __init__( + self, + cls_weight=2.0, + alpha=0.25, + gamma=2, + eps=1e-12, + box_weight=0.25, + reg_weights=None, + cls_wise_reg_weights=None, + num_dn_groups=0, + dn_noise_scale=0.5, + max_dn_gt=32, + add_neg_dn=True, + num_temp_dn_groups=0, + ): + super(SparseBox3DTarget, self).__init__( + num_dn_groups, num_temp_dn_groups + ) + self.cls_weight = cls_weight + self.box_weight = box_weight + self.alpha = alpha + self.gamma = gamma + self.eps = eps + self.reg_weights = reg_weights + if self.reg_weights is None: + self.reg_weights = [1.0] * 8 + [0.0] * 2 + self.cls_wise_reg_weights = cls_wise_reg_weights + self.dn_noise_scale = dn_noise_scale + self.max_dn_gt = max_dn_gt + self.add_neg_dn = add_neg_dn + + def encode_reg_target(self, box_target, device=None): + outputs = [] + for box in box_target: + output = torch.cat( + [ + box[..., [X, Y, Z]], + box[..., [W, L, H]].log(), + torch.sin(box[..., YAW]).unsqueeze(-1), + torch.cos(box[..., YAW]).unsqueeze(-1), + box[..., YAW + 1 :], + ], + dim=-1, + ) + if device is not None: + output = output.to(device=device) + outputs.append(output) + return outputs + + def sample( + self, + cls_pred, + box_pred, + cls_target, + box_target, + ): + bs, num_pred, num_cls = cls_pred.shape + + cls_cost = self._cls_cost(cls_pred, cls_target) + + box_target = self.encode_reg_target(box_target, box_pred.device) + + instance_reg_weights = [] + for i in range(len(box_target)): + weights = torch.logical_not(box_target[i].isnan()).to( + dtype=box_target[i].dtype + ) + if self.cls_wise_reg_weights is not None: + for cls, weight in self.cls_wise_reg_weights.items(): + weights = torch.where( + (cls_target[i] == cls)[:, None], + 
weights.new_tensor(weight), + weights, + ) + instance_reg_weights.append(weights) + box_cost = self._box_cost(box_pred, box_target, instance_reg_weights) + + indices = [] + for i in range(bs): + if cls_cost[i] is not None and box_cost[i] is not None: + cost = (cls_cost[i] + box_cost[i]).detach().cpu().numpy() + cost = np.where(np.isneginf(cost) | np.isnan(cost), 1e8, cost) + assign = linear_sum_assignment(cost) + indices.append( + [cls_pred.new_tensor(x, dtype=torch.int64) for x in assign] + ) + else: + indices.append([None, None]) + + output_cls_target = ( + cls_target[0].new_ones([bs, num_pred], dtype=torch.long) * num_cls + ) + output_box_target = box_pred.new_zeros(box_pred.shape) + output_reg_weights = box_pred.new_zeros(box_pred.shape) + for i, (pred_idx, target_idx) in enumerate(indices): + if len(cls_target[i]) == 0: + continue + output_cls_target[i, pred_idx] = cls_target[i][target_idx] + output_box_target[i, pred_idx] = box_target[i][target_idx] + output_reg_weights[i, pred_idx] = instance_reg_weights[i][ + target_idx + ] + self.indices = indices + return output_cls_target, output_box_target, output_reg_weights + + def _cls_cost(self, cls_pred, cls_target): + bs = cls_pred.shape[0] + cls_pred = cls_pred.sigmoid() + cost = [] + for i in range(bs): + if len(cls_target[i]) > 0: + neg_cost = ( + -(1 - cls_pred[i] + self.eps).log() + * (1 - self.alpha) + * cls_pred[i].pow(self.gamma) + ) + pos_cost = ( + -(cls_pred[i] + self.eps).log() + * self.alpha + * (1 - cls_pred[i]).pow(self.gamma) + ) + cost.append( + (pos_cost[:, cls_target[i]] - neg_cost[:, cls_target[i]]) + * self.cls_weight + ) + else: + cost.append(None) + return cost + + def _box_cost(self, box_pred, box_target, instance_reg_weights): + bs = box_pred.shape[0] + cost = [] + for i in range(bs): + if len(box_target[i]) > 0: + cost.append( + torch.sum( + torch.abs(box_pred[i, :, None] - box_target[i][None]) + * instance_reg_weights[i][None] + * box_pred.new_tensor(self.reg_weights), + dim=-1, + ) + * 
    def get_dn_anchors(self, cls_target, box_target, gt_instance_id=None):
        """Build noisy-GT ("denoising") anchors and their training targets.

        Pads every sample's GT to a common length, tiles them over
        ``num_dn_groups`` groups, adds positive (and optionally negative)
        noise, then matches noisy anchors back to GT via Hungarian
        assignment on the box cost.

        Returns None when denoising is disabled or the batch has no GT;
        otherwise a 6-tuple (dn_anchor, dn_box_target, dn_cls_target,
        attn_mask, valid_mask, dn_id_target).
        """
        if self.num_dn_groups <= 0:
            return None
        if self.num_temp_dn_groups <= 0:
            # Instance ids are only needed for temporal denoising.
            gt_instance_id = None

        if self.max_dn_gt > 0:
            # Cap the number of GT boxes used for denoising per sample.
            cls_target = [x[: self.max_dn_gt] for x in cls_target]
            box_target = [x[: self.max_dn_gt] for x in box_target]
            if gt_instance_id is not None:
                gt_instance_id = [x[: self.max_dn_gt] for x in gt_instance_id]

        max_dn_gt = max([len(x) for x in cls_target])
        if max_dn_gt == 0:
            return None
        # Pad per-sample GT to max_dn_gt; class id -1 marks padding.
        cls_target = torch.stack(
            [
                F.pad(x, (0, max_dn_gt - x.shape[0]), value=-1)
                for x in cls_target
            ]
        )
        box_target = self.encode_reg_target(box_target, cls_target.device)
        box_target = torch.stack(
            [F.pad(x, (0, 0, 0, max_dn_gt - x.shape[0])) for x in box_target]
        )
        # Zero the box targets of padded slots.
        box_target = torch.where(
            cls_target[..., None] == -1, box_target.new_tensor(0), box_target
        )
        if gt_instance_id is not None:
            gt_instance_id = torch.stack(
                [
                    F.pad(x, (0, max_dn_gt - x.shape[0]), value=-1)
                    for x in gt_instance_id
                ]
            )

        bs, num_gt, state_dims = box_target.shape
        if self.num_dn_groups > 1:
            # Replicate GT along the batch dim, one copy per DN group.
            cls_target = cls_target.tile(self.num_dn_groups, 1)
            box_target = box_target.tile(self.num_dn_groups, 1, 1)
            if gt_instance_id is not None:
                gt_instance_id = gt_instance_id.tile(self.num_dn_groups, 1)

        # Positive noise: uniform in [-scale, scale] around each GT box.
        noise = torch.rand_like(box_target) * 2 - 1
        noise *= box_target.new_tensor(self.dn_noise_scale)
        dn_anchor = box_target + noise
        if self.add_neg_dn:
            # Negative noise: magnitude in [scale, 2*scale] with random sign,
            # i.e. deliberately far from the GT box.
            noise_neg = torch.rand_like(box_target) + 1
            flag = torch.where(
                torch.rand_like(box_target) > 0.5,
                noise_neg.new_tensor(1),
                noise_neg.new_tensor(-1),
            )
            noise_neg *= flag
            noise_neg *= box_target.new_tensor(self.dn_noise_scale)
            dn_anchor = torch.cat([dn_anchor, box_target + noise_neg], dim=1)
            num_gt *= 2

        box_cost = self._box_cost(
            dn_anchor, box_target, torch.ones_like(box_target)
        )
        dn_box_target = torch.zeros_like(dn_anchor)
        # -3 marks unmatched/negative DN slots (tested as == -3 below).
        dn_cls_target = -torch.ones_like(cls_target) * 3
        if gt_instance_id is not None:
            dn_id_target = -torch.ones_like(gt_instance_id)
        if self.add_neg_dn:
            dn_cls_target = torch.cat([dn_cls_target, dn_cls_target], dim=1)
            if gt_instance_id is not None:
                dn_id_target = torch.cat([dn_id_target, dn_id_target], dim=1)

        # Hungarian-match noisy anchors back to GT, per (group x sample) row.
        for i in range(dn_anchor.shape[0]):
            cost = box_cost[i].cpu().numpy()
            anchor_idx, gt_idx = linear_sum_assignment(cost)
            anchor_idx = dn_anchor.new_tensor(anchor_idx, dtype=torch.int64)
            gt_idx = dn_anchor.new_tensor(gt_idx, dtype=torch.int64)
            dn_box_target[i, anchor_idx] = box_target[i, gt_idx]
            dn_cls_target[i, anchor_idx] = cls_target[i, gt_idx]
            if gt_instance_id is not None:
                dn_id_target[i, anchor_idx] = gt_instance_id[i, gt_idx]
        # Reshape from (groups*bs, num_gt, ...) to (bs, groups*num_gt, ...).
        dn_anchor = (
            dn_anchor.reshape(self.num_dn_groups, bs, num_gt, state_dims)
            .permute(1, 0, 2, 3)
            .flatten(1, 2)
        )
        dn_box_target = (
            dn_box_target.reshape(self.num_dn_groups, bs, num_gt, state_dims)
            .permute(1, 0, 2, 3)
            .flatten(1, 2)
        )
        dn_cls_target = (
            dn_cls_target.reshape(self.num_dn_groups, bs, num_gt)
            .permute(1, 0, 2)
            .flatten(1)
        )
        if gt_instance_id is not None:
            dn_id_target = (
                dn_id_target.reshape(self.num_dn_groups, bs, num_gt)
                .permute(1, 0, 2)
                .flatten(1)
            )
        else:
            dn_id_target = None
        valid_mask = dn_cls_target >= 0
        if self.add_neg_dn:
            cls_target = (
                torch.cat([cls_target, cls_target], dim=1)
                .reshape(self.num_dn_groups, bs, num_gt)
                .permute(1, 0, 2)
                .flatten(1)
            )
            # Negative-noise slots of real (non-padded) GT also count as
            # valid supervision targets; valid == "not from padding".
            valid_mask = torch.logical_or(
                valid_mask, ((cls_target >= 0) & (dn_cls_target == -3))
            )
        # Block attention across DN groups: True = masked-out.
        attn_mask = dn_box_target.new_ones(
            num_gt * self.num_dn_groups, num_gt * self.num_dn_groups
        )
        for i in range(self.num_dn_groups):
            start = num_gt * i
            end = start + num_gt
            attn_mask[start:end, start:end] = 0
        attn_mask = attn_mask == 1
        dn_cls_target = dn_cls_target.long()
        return (
            dn_anchor,
            dn_box_target,
            dn_cls_target,
            attn_mask,
            valid_mask,
            dn_id_target,
        )
    def update_dn(
        self,
        instance_feature,
        anchor,
        dn_reg_target,
        dn_cls_target,
        valid_mask,
        dn_id_target,
        num_noraml_anchor,  # (sic) number of non-DN anchors; name kept for callers
        temporal_valid_mask,
    ):
        """Merge cached temporal DN instances with the current frame's DN.

        The first ``num_temp_dn_groups`` DN groups are replaced (per sample,
        where ``temporal_valid_mask`` is True) by the instances cached in
        ``self.dn_metas`` from the previous frame, with their targets
        re-associated to the current frame via instance ids.

        Returns a 6-element list: [instance_feature, anchor, dn_reg_target,
        dn_cls_target, valid_mask, dn_id_target], each with DN entries
        concatenated after the normal anchors where applicable.
        """
        bs, num_anchor = instance_feature.shape[:2]
        if temporal_valid_mask is None:
            # No temporal context at all -> drop any stale cache.
            self.dn_metas = None
        if self.dn_metas is None or num_noraml_anchor >= num_anchor:
            # Nothing cached, or no DN anchors appended: pass through.
            return (
                instance_feature,
                anchor,
                dn_reg_target,
                dn_cls_target,
                valid_mask,
                dn_id_target,
            )

        # split instance_feature and anchor into non-dn and dn
        num_dn = num_anchor - num_noraml_anchor
        dn_instance_feature = instance_feature[:, -num_dn:]
        dn_anchor = anchor[:, -num_dn:]
        instance_feature = instance_feature[:, :num_noraml_anchor]
        anchor = anchor[:, :num_noraml_anchor]

        # reshape all dn metas from (bs, num_all_dn, xxx)
        # to (bs, dn_group, num_dn_per_group, xxx)
        num_dn_groups = self.num_dn_groups
        num_dn = num_dn // num_dn_groups
        dn_feat = dn_instance_feature.reshape(bs, num_dn_groups, num_dn, -1)
        dn_anchor = dn_anchor.reshape(bs, num_dn_groups, num_dn, -1)
        dn_reg_target = dn_reg_target.reshape(bs, num_dn_groups, num_dn, -1)
        dn_cls_target = dn_cls_target.reshape(bs, num_dn_groups, num_dn)
        valid_mask = valid_mask.reshape(bs, num_dn_groups, num_dn)
        if dn_id_target is not None:
            dn_id = dn_id_target.reshape(bs, num_dn_groups, num_dn)

        # update temp_dn_metas by instance_id
        temp_dn_feat = self.dn_metas["dn_instance_feature"]
        _, num_temp_dn_groups, num_temp_dn = temp_dn_feat.shape[:3]
        temp_dn_id = self.dn_metas["dn_id_target"]

        # match: (bs, num_temp_dn_groups, num_temp_dn, num_dn) -- True where
        # a cached DN instance matches a current-frame DN instance by id.
        match = temp_dn_id[..., None] == dn_id[:, :num_temp_dn_groups, None]
        # Pull the matched current-frame regression target for each cached
        # instance (at most one match per row, so sum() selects it).
        temp_reg_target = (
            match[..., None] * dn_reg_target[:, :num_temp_dn_groups, None]
        ).sum(dim=3)
        # Cached instances with no match in the current frame get class -1.
        temp_cls_target = torch.where(
            torch.all(torch.logical_not(match), dim=-1),
            self.dn_metas["dn_cls_target"].new_tensor(-1),
            self.dn_metas["dn_cls_target"],
        )
        temp_valid_mask = self.dn_metas["valid_mask"]
        temp_dn_anchor = self.dn_metas["dn_anchor"]

        # handle the misalignment of the length of temp_dn vs dn caused by
        # the change of num_gt, then concat the temp_dn and dn
        temp_dn_metas = [
            temp_dn_feat,
            temp_dn_anchor,
            temp_reg_target,
            temp_cls_target,
            temp_valid_mask,
            temp_dn_id,
        ]
        dn_metas = [
            dn_feat,
            dn_anchor,
            dn_reg_target,
            dn_cls_target,
            valid_mask,
            dn_id,
        ]
        output = []
        for i, (temp_meta, meta) in enumerate(zip(temp_dn_metas, dn_metas)):
            if num_temp_dn < num_dn:
                # Pad cached metas along the per-group DN dim.
                pad = (0, num_dn - num_temp_dn)
                if temp_meta.dim() == 4:
                    pad = (0, 0) + pad
                else:
                    assert temp_meta.dim() == 3
                temp_meta = F.pad(temp_meta, pad, value=0)
            else:
                temp_meta = temp_meta[:, :, :num_dn]
            # Per-sample: use cached metas only where the temporal history
            # is valid; otherwise fall back to the current frame's DN.
            mask = temporal_valid_mask[:, None, None]
            if meta.dim() == 4:
                mask = mask.unsqueeze(dim=-1)
            temp_meta = torch.where(
                mask, temp_meta, meta[:, :num_temp_dn_groups]
            )
            meta = torch.cat([temp_meta, meta[:, num_temp_dn_groups:]], dim=1)
            meta = meta.flatten(1, 2)
            output.append(meta)
        output[0] = torch.cat([instance_feature, output[0]], dim=1)
        output[1] = torch.cat([anchor, output[1]], dim=1)
        return output
dn_instance_feature.detach().reshape( + bs, num_dn_groups, num_temp_dn, -1 + )[:, temp_group_mask] + dn_anchor = dn_anchor.detach().reshape( + bs, num_dn_groups, num_temp_dn, -1 + )[:, temp_group_mask] + dn_cls_target = dn_cls_target.reshape(bs, num_dn_groups, num_temp_dn)[ + :, temp_group_mask + ] + valid_mask = valid_mask.reshape(bs, num_dn_groups, num_temp_dn)[ + :, temp_group_mask + ] + if dn_id_target is not None: + dn_id_target = dn_id_target.reshape( + bs, num_dn_groups, num_temp_dn + )[:, temp_group_mask] + self.dn_metas = dict( + dn_instance_feature=dn_instance_feature, + dn_anchor=dn_anchor, + dn_cls_target=dn_cls_target, + valid_mask=valid_mask, + dn_id_target=dn_id_target, + ) diff --git a/projects/mmdet3d_plugin/models/grid_mask.py b/projects/mmdet3d_plugin/models/grid_mask.py new file mode 100644 index 0000000..ead0caa --- /dev/null +++ b/projects/mmdet3d_plugin/models/grid_mask.py @@ -0,0 +1,138 @@ +import torch +import torch.nn as nn +import numpy as np +from PIL import Image + + +class Grid(object): + def __init__( + self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0 + ): + self.use_h = use_h + self.use_w = use_w + self.rotate = rotate + self.offset = offset + self.ratio = ratio + self.mode = mode + self.st_prob = prob + self.prob = prob + + def set_prob(self, epoch, max_epoch): + self.prob = self.st_prob * epoch / max_epoch + + def __call__(self, img, label): + if np.random.rand() > self.prob: + return img, label + h = img.size(1) + w = img.size(2) + self.d1 = 2 + self.d2 = min(h, w) + hh = int(1.5 * h) + ww = int(1.5 * w) + d = np.random.randint(self.d1, self.d2) + if self.ratio == 1: + self.l = np.random.randint(1, d) + else: + self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1) + mask = np.ones((hh, ww), np.float32) + st_h = np.random.randint(d) + st_w = np.random.randint(d) + if self.use_h: + for i in range(hh // d): + s = d * i + st_h + t = min(s + self.l, hh) + mask[s:t, :] *= 0 + if self.use_w: + for i in range(ww // 
class GridMask(nn.Module):
    """GridMask augmentation: zero out a rotated grid pattern on the input.

    Fix vs. original: the mask/offset tensors were created with a hard-coded
    ``.cuda()``, which crashed on CPU inputs. They now follow ``x.device``,
    giving identical results on GPU while also working on CPU.
    """

    def __init__(
        self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0
    ):
        super(GridMask, self).__init__()
        self.use_h = use_h      # mask horizontal stripes
        self.use_w = use_w      # mask vertical stripes
        self.rotate = rotate    # max rotation (degrees) of the grid pattern
        self.offset = offset    # add random offset in the masked region
        self.ratio = ratio      # stripe width / cell size
        self.mode = mode        # 1 inverts the mask
        self.st_prob = prob     # base probability, scaled by set_prob
        self.prob = prob

    def set_prob(self, epoch, max_epoch):
        # Linearly ramp the masking probability over training.
        self.prob = self.st_prob * epoch / max_epoch

    def forward(self, x):
        """Apply GridMask to a (n, c, h, w) tensor during training only."""
        if np.random.rand() > self.prob or not self.training:
            return x
        n, c, h, w = x.size()
        x = x.view(-1, h, w)
        # Build the mask on an enlarged canvas so rotation has no border gaps.
        hh = int(1.5 * h)
        ww = int(1.5 * w)
        d = np.random.randint(2, h)  # grid cell size
        self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
        mask = np.ones((hh, ww), np.float32)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                t = min(s + self.l, hh)
                mask[s:t, :] *= 0
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                t = min(s + self.l, ww)
                mask[:, s:t] *= 0

        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
        # Crop the rotated canvas back to the input size.
        mask = mask[
            (hh - h) // 2 : (hh - h) // 2 + h,
            (ww - w) // 2 : (ww - w) // 2 + w,
        ]

        # .copy() because PIL's asarray output may be non-writable;
        # follow the input device instead of hard-coding CUDA.
        mask = torch.from_numpy(mask.copy()).float().to(x.device)
        if self.mode == 1:
            mask = 1 - mask
        mask = mask.expand_as(x)
        if self.offset:
            offset = (
                torch.from_numpy(2 * (np.random.rand(h, w) - 0.5))
                .float()
                .to(x.device)
            )
            x = x * mask + offset * (1 - mask)
        else:
            x = x * mask

        return x.view(n, c, h, w)
+ """ + + def __init__( + self, + task_config: dict, + det_head: dict = None, + map_head: dict = None, + motion_plan_head: dict = None, + num_classes: int = 10, + num_map_classes: int = 3, + init_cfg=None, + **kwargs, + ): + super(GTSparseDriveHead, self).__init__(init_cfg) + self.task_config = task_config + self.num_classes = num_classes + + assert det_head is not None, ( + "det_head config is required to provide anchor_encoder " + "and instance_bank sub-modules." + ) + self.det_head = build_head(det_head) + # Freeze all det_head parameters: we only use its sub-modules + # (anchor_encoder, instance_bank) as non-trainable components. + # This prevents DDP from complaining about unused parameters. + for p in self.det_head.parameters(): + p.requires_grad_(False) + + self.num_map_classes = num_map_classes + if map_head is not None: + self.map_head = build_head(map_head) + # Freeze map_head parameters: only anchor_encoder and + # instance_bank sub-modules are used (non-trainable). + for p in self.map_head.parameters(): + p.requires_grad_(False) + + assert motion_plan_head is not None + self.motion_plan_head = build_head(motion_plan_head) + + def init_weights(self): + self.det_head.init_weights() + if hasattr(self, "map_head"): + self.map_head.init_weights() + self.motion_plan_head.init_weights() + + # ------------------------------------------------------------------ # + # Forward + # ------------------------------------------------------------------ # + + def forward(self, feature_maps, metas: dict): + batch_size = len(metas["img_metas"]) + device = ( + feature_maps[0].device + if isinstance(feature_maps[0], torch.Tensor) + else feature_maps[0][0].device + ) + + # 1. Update instance_bank temporal mask. + # We discard the returned bank features; we only need self.mask. + self.det_head.instance_bank.get(batch_size, metas) + + # 2. Build det_output from GT boxes. + det_output = self._build_gt_det_output(metas, batch_size, feature_maps) + + # 3. 
Cache GT features/anchors in the bank for temporal tracking. + self.det_head.instance_bank.cache( + det_output["instance_feature"], + det_output["prediction"][-1], + det_output["classification"][-1], + metas, + feature_maps, + ) + + # 4. Build GT map_output if map_head sub-modules are available. + if hasattr(self, "map_head"): + map_output = self._build_gt_map_output(metas, batch_size, device) + else: + map_output = None + + # 5. Forward motion/planning head. + motion_output, planning_output = self.motion_plan_head( + det_output, + map_output, + feature_maps, + metas, + self.det_head.anchor_encoder, + self.det_head.instance_bank.mask, + self.det_head.instance_bank.anchor_handler, + ) + + return det_output, map_output, motion_output, planning_output + + # ------------------------------------------------------------------ # + # GT det_output construction helpers + # ------------------------------------------------------------------ # + + @staticmethod + def _encode_gt_boxes(boxes: torch.Tensor) -> torch.Tensor: + """Convert decoded 9-dim GT boxes to encoded 11-dim anchor format. 
+ + Input: (N, >=7) [..., x, y, z, w, l, h, yaw, vx, vy] + Output: (N, 11) [x, y, z, log_w, log_l, log_h, + sin_yaw, cos_yaw, vx, vy, 0] + """ + xyz = boxes[:, :3] + wlh = boxes[:, 3:6].clamp(min=1e-3).log() + sin_yaw = torch.sin(boxes[:, YAW : YAW + 1]) + cos_yaw = torch.cos(boxes[:, YAW : YAW + 1]) + vel = boxes[:, 7:9] if boxes.shape[-1] >= 9 else boxes.new_zeros(len(boxes), 2) + vz = boxes.new_zeros(len(boxes), 1) + return torch.cat([xyz, wlh, sin_yaw, cos_yaw, vel, vz], dim=-1) + + def _build_gt_det_output( + self, metas: dict, batch_size: int, feature_maps + ) -> dict: + """Build the det_output dict populated with GT boxes.""" + if isinstance(feature_maps[0], torch.Tensor): + device = feature_maps[0].device + else: + device = feature_maps[0][0].device + + num_anchor = self.det_head.instance_bank.num_anchor + embed_dims = self.det_head.instance_bank.embed_dims + + gt_bboxes = metas["gt_bboxes_3d"] # list[Tensor(N_i, 9)] + gt_labels = metas["gt_labels_3d"] # list[Tensor(N_i,)] + + anchors = torch.zeros(batch_size, num_anchor, 11, device=device) + # Very-negative logits → near-zero confidence after sigmoid (padding). + cls_logits = anchors.new_full( + (batch_size, num_anchor, self.num_classes), -100.0 + ) + + for i in range(batch_size): + bboxes_i = gt_bboxes[i] + labels_i = gt_labels[i] + + if not isinstance(bboxes_i, torch.Tensor): + bboxes_i = torch.tensor( + bboxes_i, device=device, dtype=torch.float32 + ) + else: + bboxes_i = bboxes_i.to(device=device, dtype=torch.float32) + + if not isinstance(labels_i, torch.Tensor): + labels_i = torch.tensor( + labels_i, device=device, dtype=torch.long + ) + else: + labels_i = labels_i.to(device=device) + + N_i = len(bboxes_i) + if N_i == 0: + continue + N_i = min(N_i, num_anchor) + bboxes_i = bboxes_i[:N_i] + labels_i = labels_i[:N_i] + + anchors[i, :N_i] = self._encode_gt_boxes(bboxes_i) + + # High-positive logit at GT class, very-negative elsewhere. 
+ cls_logits[i, :N_i] = -100.0 + cls_logits[i, torch.arange(N_i, device=device), labels_i] = 100.0 + + # Anchor embeddings from the det_head's encoder. + anchor_embed = self.det_head.anchor_encoder(anchors) + + # Instance features initialised to zero; the motion GNN refines them. + instance_feature = torch.zeros( + batch_size, num_anchor, embed_dims, device=device + ) + + # GT instance IDs for temporal tracking in InstanceQueue. + instance_id = self._get_gt_instance_ids( + metas, batch_size, num_anchor, device + ) + + return { + "instance_feature": instance_feature, + "anchor_embed": anchor_embed, + "classification": [cls_logits], + "prediction": [anchors], + "quality": [None], + "instance_id": instance_id, + } + + def _build_gt_map_output( + self, metas: dict, batch_size: int, device + ) -> dict: + """Build the map_output dict populated with GT map polylines. + + GT map pts are encoded via the map_head's anchor_encoder into the + same anchor_embed space that MotionPlanningHead's cross_gnn expects. + Classification logits are set to +100 at the GT class so that the + confidence-based top-k selection in MotionPlanningHead picks the + true map elements. + + Args: + metas: Batch metas containing 'gt_map_labels' and 'gt_map_pts'. + gt_map_labels: list[Tensor(M_i,)] + gt_map_pts: list[Tensor(M_i, num_sample, 2)] (test) + or list[Tensor(M_i, num_perms, num_sample, 2)] + (train, when VectorizeMap has permute=True). 
+ """ + num_anchor = self.map_head.instance_bank.num_anchor # 100 + embed_dims = self.map_head.instance_bank.embed_dims # 256 + num_sample_x2 = self.map_head.anchor_encoder.input_dims # num_sample * 2 + + gt_map_labels = metas["gt_map_labels"] # list[Tensor(M_i,)] + gt_map_pts = metas["gt_map_pts"] # list[Tensor(...)] + + predictions = torch.zeros( + batch_size, num_anchor, num_sample_x2, device=device + ) + cls_logits = torch.full( + (batch_size, num_anchor, self.num_map_classes), -100.0, device=device + ) + + for i in range(batch_size): + map_pts_i = gt_map_pts[i] + map_labels_i = gt_map_labels[i] + + if not isinstance(map_pts_i, torch.Tensor): + map_pts_i = torch.tensor( + map_pts_i, device=device, dtype=torch.float32 + ) + else: + map_pts_i = map_pts_i.to(device=device, dtype=torch.float32) + + if not isinstance(map_labels_i, torch.Tensor): + map_labels_i = torch.tensor( + map_labels_i, device=device, dtype=torch.long + ) + else: + map_labels_i = map_labels_i.to(device=device) + + M_i = len(map_pts_i) + if M_i == 0: + continue + M_i = min(M_i, num_anchor) + map_pts_i = map_pts_i[:M_i] + map_labels_i = map_labels_i[:M_i] + + # Train pipeline uses permute=True: (M, num_perms, num_sample, 2). + # Take the first permutation (canonical polyline direction). + if map_pts_i.dim() == 4: + map_pts_i = map_pts_i[:, 0] # (M, num_sample, 2) + + predictions[i, :M_i] = map_pts_i.reshape(M_i, -1) + cls_logits[i, :M_i] = -100.0 + cls_logits[i, torch.arange(M_i, device=device), map_labels_i] = 100.0 + + # Encode GT map point coordinates into positional embeddings. + anchor_embed = self.map_head.anchor_encoder(predictions) + + # Instance features initialised to zero; cross_gnn will attend to + # anchor_embed (position) rather than learned instance content. 
+ instance_feature = torch.zeros( + batch_size, num_anchor, embed_dims, device=device + ) + + return { + "instance_feature": instance_feature, + "anchor_embed": anchor_embed, + "classification": [cls_logits], + "prediction": [predictions], + } + + @staticmethod + def _get_gt_instance_ids( + metas: dict, + batch_size: int, + num_anchor: int, + device, + ) -> torch.Tensor: + """Read GT instance IDs from metas, padded to (bs, num_anchor).""" + instance_id = torch.full( + (batch_size, num_anchor), -1, dtype=torch.long, device=device + ) + for i, img_meta in enumerate(metas["img_metas"]): + gt_id = img_meta.get("instance_id", None) + if gt_id is None: + continue + if not isinstance(gt_id, torch.Tensor): + gt_id = torch.tensor(gt_id, dtype=torch.long, device=device) + else: + gt_id = gt_id.to(device=device) + N_i = min(len(gt_id), num_anchor) + instance_id[i, :N_i] = gt_id[:N_i] + return instance_id + + # ------------------------------------------------------------------ # + # Loss + # ------------------------------------------------------------------ # + + def loss(self, model_outs, data): + _, _, motion_output, planning_output = model_outs + + motion_loss_cache = dict( + indices=self._build_identity_indices(data), + ) + return self.motion_plan_head.loss( + motion_output, planning_output, data, motion_loss_cache + ) + + def _build_identity_indices(self, data) -> List: + """Identity matching: pred anchor i → GT agent i for each batch item. + + The motion sampler uses these indices to assign GT future trajectories + to predicted agent slots. With GT-as-detection, agent i directly + corresponds to GT agent i, so both pred_idx and target_idx are [0..N_i). 
+ """ + gt_labels = data["gt_labels_3d"] # list of (N_i,) tensors + device = next(self.parameters()).device + indices = [] + for i in range(len(gt_labels)): + N_i = len(gt_labels[i]) + if N_i == 0: + indices.append([None, None]) + else: + idx = torch.arange(N_i, device=device, dtype=torch.long) + indices.append([idx, idx]) + return indices + + # ------------------------------------------------------------------ # + # Post-process + # ------------------------------------------------------------------ # + + def post_process(self, model_outs, data): + det_output, _, motion_output, planning_output = model_outs + + det_result = self.det_head.post_process(det_output) + motion_result, planning_result = self.motion_plan_head.post_process( + det_output, motion_output, planning_output, data + ) + + batch_size = len(motion_result) + results = [dict() for _ in range(batch_size)] + for i in range(batch_size): + results[i].update(det_result[i]) + results[i].update(motion_result[i]) + results[i].update(planning_result[i]) + + return results diff --git a/projects/mmdet3d_plugin/models/instance_bank.py b/projects/mmdet3d_plugin/models/instance_bank.py new file mode 100644 index 0000000..49a0353 --- /dev/null +++ b/projects/mmdet3d_plugin/models/instance_bank.py @@ -0,0 +1,265 @@ +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np + +from mmcv.utils import build_from_cfg +from mmcv.cnn.bricks.registry import PLUGIN_LAYERS + +__all__ = ["InstanceBank"] + + +def topk(confidence, k, *inputs): + bs, N = confidence.shape[:2] + confidence, indices = torch.topk(confidence, k, dim=1) + indices = ( + indices + torch.arange(bs, device=indices.device)[:, None] * N + ).reshape(-1) + outputs = [] + for input in inputs: + outputs.append(input.flatten(end_dim=1)[indices].reshape(bs, k, -1)) + return confidence, outputs + + +@PLUGIN_LAYERS.register_module() +class InstanceBank(nn.Module): + def __init__( + self, + num_anchor, + embed_dims, + anchor, + 
anchor_handler=None, + num_temp_instances=0, + default_time_interval=0.5, + confidence_decay=0.6, + anchor_grad=True, + feat_grad=True, + max_time_interval=2, + ): + super(InstanceBank, self).__init__() + self.embed_dims = embed_dims + self.num_temp_instances = num_temp_instances + self.default_time_interval = default_time_interval + self.confidence_decay = confidence_decay + self.max_time_interval = max_time_interval + + if anchor_handler is not None: + anchor_handler = build_from_cfg(anchor_handler, PLUGIN_LAYERS) + assert hasattr(anchor_handler, "anchor_projection") + self.anchor_handler = anchor_handler + if isinstance(anchor, str): + anchor = np.load(anchor) + elif isinstance(anchor, (list, tuple)): + anchor = np.array(anchor) + if len(anchor.shape) == 3: # for map + anchor = anchor.reshape(anchor.shape[0], -1) + self.num_anchor = min(len(anchor), num_anchor) + anchor = anchor[:num_anchor] + self.anchor = nn.Parameter( + torch.tensor(anchor, dtype=torch.float32), + requires_grad=anchor_grad, + ) + self.anchor_init = anchor + self.instance_feature = nn.Parameter( + torch.zeros([self.anchor.shape[0], self.embed_dims]), + requires_grad=feat_grad, + ) + self.reset() + + def init_weight(self): + self.anchor.data = self.anchor.data.new_tensor(self.anchor_init) + if self.instance_feature.requires_grad: + torch.nn.init.xavier_uniform_(self.instance_feature.data, gain=1) + + def reset(self): + self.cached_feature = None + self.cached_anchor = None + self.metas = None + self.mask = None + self.confidence = None + self.temp_confidence = None + self.instance_id = None + self.prev_id = 0 + + def get(self, batch_size, metas=None, dn_metas=None): + instance_feature = torch.tile( + self.instance_feature[None], (batch_size, 1, 1) + ) + anchor = torch.tile(self.anchor[None], (batch_size, 1, 1)) + + if ( + self.cached_anchor is not None + and batch_size == self.cached_anchor.shape[0] + ): + history_time = self.metas["timestamp"] + time_interval = metas["timestamp"] - 
history_time + time_interval = time_interval.to(dtype=instance_feature.dtype) + self.mask = torch.abs(time_interval) <= self.max_time_interval + + if self.anchor_handler is not None: + T_temp2cur = self.cached_anchor.new_tensor( + np.stack( + [ + x["T_global_inv"] + @ self.metas["img_metas"][i]["T_global"] + for i, x in enumerate(metas["img_metas"]) + ] + ) + ) + self.cached_anchor = self.anchor_handler.anchor_projection( + self.cached_anchor, + [T_temp2cur], + time_intervals=[-time_interval], + )[0] + + if ( + self.anchor_handler is not None + and dn_metas is not None + and batch_size == dn_metas["dn_anchor"].shape[0] + ): + num_dn_group, num_dn = dn_metas["dn_anchor"].shape[1:3] + dn_anchor = self.anchor_handler.anchor_projection( + dn_metas["dn_anchor"].flatten(1, 2), + [T_temp2cur], + time_intervals=[-time_interval], + )[0] + dn_metas["dn_anchor"] = dn_anchor.reshape( + batch_size, num_dn_group, num_dn, -1 + ) + time_interval = torch.where( + torch.logical_and(time_interval != 0, self.mask), + time_interval, + time_interval.new_tensor(self.default_time_interval), + ) + else: + self.reset() + time_interval = instance_feature.new_tensor( + [self.default_time_interval] * batch_size + ) + + return ( + instance_feature, + anchor, + self.cached_feature, + self.cached_anchor, + time_interval, + ) + + def update(self, instance_feature, anchor, confidence, + cached_feature_override=None, cached_anchor_override=None): + if self.cached_feature is None: + return instance_feature, anchor + + num_dn = 0 + if instance_feature.shape[1] > self.num_anchor: + num_dn = instance_feature.shape[1] - self.num_anchor + dn_instance_feature = instance_feature[:, -num_dn:] + dn_anchor = anchor[:, -num_dn:] + instance_feature = instance_feature[:, : self.num_anchor] + anchor = anchor[:, : self.num_anchor] + confidence = confidence[:, : self.num_anchor] + + cached_feat = cached_feature_override if cached_feature_override is not None \ + else self.cached_feature + cached_anch = 
cached_anchor_override if cached_anchor_override is not None \ + else self.cached_anchor + + N = self.num_anchor - self.num_temp_instances + confidence = confidence.max(dim=-1).values + _, (selected_feature, selected_anchor) = topk( + confidence, N, instance_feature, anchor + ) + selected_feature = torch.cat( + [cached_feat, selected_feature], dim=1 + ) + selected_anchor = torch.cat( + [cached_anch, selected_anchor], dim=1 + ) + instance_feature = torch.where( + self.mask[:, None, None], selected_feature, instance_feature + ) + anchor = torch.where(self.mask[:, None, None], selected_anchor, anchor) + self.confidence = torch.where( + self.mask[:, None], + self.confidence, + self.confidence.new_tensor(0) + ) + if self.instance_id is not None: + self.instance_id = torch.where( + self.mask[:, None], + self.instance_id, + self.instance_id.new_tensor(-1), + ) + + if num_dn > 0: + instance_feature = torch.cat( + [instance_feature, dn_instance_feature], dim=1 + ) + anchor = torch.cat([anchor, dn_anchor], dim=1) + return instance_feature, anchor + + def cache( + self, + instance_feature, + anchor, + confidence, + metas=None, + feature_maps=None, + ): + if self.num_temp_instances <= 0: + return + instance_feature = instance_feature.detach() + anchor = anchor.detach() + confidence = confidence.detach() + + self.metas = metas + confidence = confidence.max(dim=-1).values.sigmoid() + if self.confidence is not None: + confidence[:, : self.num_temp_instances] = torch.maximum( + self.confidence * self.confidence_decay, + confidence[:, : self.num_temp_instances], + ) + self.temp_confidence = confidence + + ( + self.confidence, + (self.cached_feature, self.cached_anchor), + ) = topk(confidence, self.num_temp_instances, instance_feature, anchor) + + def get_instance_id(self, confidence, anchor=None, threshold=None): + confidence = confidence.max(dim=-1).values.sigmoid() + instance_id = confidence.new_full(confidence.shape, -1).long() + + if ( + self.instance_id is not None + and 
self.instance_id.shape[0] == instance_id.shape[0]
+        ):
+            instance_id[:, : self.instance_id.shape[1]] = self.instance_id
+
+        mask = instance_id < 0
+        if threshold is not None:
+            mask = mask & (confidence >= threshold)
+        num_new_instance = mask.sum()  # count of slots that still need a fresh track ID
+        new_ids = torch.arange(num_new_instance).to(instance_id) + self.prev_id
+        instance_id[torch.where(mask)] = new_ids
+        self.prev_id += num_new_instance  # NOTE(review): adds a 0-d tensor to prev_id (int after reset()), so prev_id becomes a tensor — confirm intended
+        self.update_instance_id(instance_id, confidence)
+        return instance_id
+
+    def update_instance_id(self, instance_id=None, confidence=None):  # cache the IDs of the top-k temporal instances for the next frame
+        if self.temp_confidence is None:
+            if confidence.dim() == 3:  # bs, num_anchor, num_cls
+                temp_conf = confidence.max(dim=-1).values
+            else:  # bs, num_anchor
+                temp_conf = confidence
+        else:
+            temp_conf = self.temp_confidence
+        instance_id = topk(temp_conf, self.num_temp_instances, instance_id)[1][
+            0
+        ]
+        instance_id = instance_id.squeeze(dim=-1)  # topk() reshapes inputs to (bs, k, 1); drop the trailing dim
+        self.instance_id = F.pad(
+            instance_id,
+            (0, self.num_anchor - self.num_temp_instances),
+            value=-1,
+        )
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/models/map/__init__.py b/projects/mmdet3d_plugin/models/map/__init__.py
new file mode 100644
index 0000000..4fa1d7d
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/map/__init__.py
@@ -0,0 +1,9 @@
+from .decoder import SparsePoint3DDecoder
+from .target import SparsePoint3DTarget, HungarianLinesAssigner
+from .match_cost import LinesL1Cost, MapQueriesCost
+from .loss import LinesL1Loss, SparseLineLoss
+from .map_blocks import (
+    SparsePoint3DRefinementModule,
+    SparsePoint3DKeyPointsGenerator,
+    SparsePoint3DEncoder,
+)
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/models/map/decoder.py b/projects/mmdet3d_plugin/models/map/decoder.py
new file mode 100644
index 0000000..b9564c9
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/map/decoder.py
@@ -0,0 +1,53 @@
+from typing import Optional, List
+
+import torch
+
+from mmdet.core.bbox.builder import BBOX_CODERS
+
+
+@BBOX_CODERS.register_module()
+class SparsePoint3DDecoder(object): + def __init__( + self, + coords_dim: int = 2, + score_threshold: Optional[float] = None, + ): + super(SparsePoint3DDecoder, self).__init__() + self.score_threshold = score_threshold + self.coords_dim = coords_dim + + def decode( + self, + cls_scores, + pts_preds, + instance_id=None, + quality=None, + output_idx=-1, + ): + bs, num_pred, num_cls = cls_scores[-1].shape + cls_scores = cls_scores[-1].sigmoid() + pts_preds = pts_preds[-1].reshape(bs, num_pred, -1, self.coords_dim) + cls_scores, indices = cls_scores.flatten(start_dim=1).topk( + num_pred, dim=1 + ) + cls_ids = indices % num_cls + if self.score_threshold is not None: + mask = cls_scores >= self.score_threshold + output = [] + for i in range(bs): + category_ids = cls_ids[i] + scores = cls_scores[i] + pts = pts_preds[i, indices[i] // num_cls] + if self.score_threshold is not None: + category_ids = category_ids[mask[i]] + scores = scores[mask[i]] + pts = pts[mask[i]] + + output.append( + { + "vectors": [vec.detach().cpu().numpy() for vec in pts], + "scores": scores.detach().cpu().numpy(), + "labels": category_ids.detach().cpu().numpy(), + } + ) + return output \ No newline at end of file diff --git a/projects/mmdet3d_plugin/models/map/loss.py b/projects/mmdet3d_plugin/models/map/loss.py new file mode 100644 index 0000000..0edd8be --- /dev/null +++ b/projects/mmdet3d_plugin/models/map/loss.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn + +from mmcv.utils import build_from_cfg +from mmdet.models.builder import LOSSES +from mmdet.models.losses import l1_loss, smooth_l1_loss + + +@LOSSES.register_module() +class LinesL1Loss(nn.Module): + + def __init__(self, reduction='mean', loss_weight=1.0, beta=0.5): + """ + L1 loss. The same as the smooth L1 loss + Args: + reduction (str, optional): The method to reduce the loss. + Options are "none", "mean" and "sum". + loss_weight (float, optional): The weight of loss. 
+ """ + + super().__init__() + self.reduction = reduction + self.loss_weight = loss_weight + self.beta = beta + + def forward(self, + pred, + target, + weight=None, + avg_factor=None, + reduction_override=None): + """Forward function. + Args: + pred (torch.Tensor): The prediction. + shape: [bs, ...] + target (torch.Tensor): The learning target of the prediction. + shape: [bs, ...] + weight (torch.Tensor, optional): The weight of loss for each + prediction. Defaults to None. + it's useful when the predictions are not all valid. + avg_factor (int, optional): Average factor that is used to average + the loss. Defaults to None. + reduction_override (str, optional): The reduction method used to + override the original reduction method of the loss. + Defaults to None. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + + if self.beta > 0: + loss = smooth_l1_loss( + pred, target, weight, reduction=reduction, avg_factor=avg_factor, beta=self.beta) + + else: + loss = l1_loss( + pred, target, weight, reduction=reduction, avg_factor=avg_factor) + + num_points = pred.shape[-1] // 2 + loss = loss / num_points + + return loss*self.loss_weight + + +@LOSSES.register_module() +class SparseLineLoss(nn.Module): + def __init__( + self, + loss_line, + num_sample=20, + roi_size=(30, 60), + ): + super().__init__() + + def build(cfg, registry): + if cfg is None: + return None + return build_from_cfg(cfg, registry) + + self.loss_line = build(loss_line, LOSSES) + self.num_sample = num_sample + self.roi_size = roi_size + + def forward( + self, + line, + line_target, + weight=None, + avg_factor=None, + prefix="", + suffix="", + **kwargs, + ): + + output = {} + line = self.normalize_line(line) + line_target = self.normalize_line(line_target) + line_loss = self.loss_line( + line, line_target, weight=weight, avg_factor=avg_factor + ) + output[f"{prefix}loss_line{suffix}"] = line_loss + + return 
output + + def normalize_line(self, line): + if line.shape[0] == 0: + return line + + line = line.view(line.shape[:-1] + (self.num_sample, -1)) + + origin = -line.new_tensor([self.roi_size[0]/2, self.roi_size[1]/2]) + line = line - origin + + # transform from range [0, 1] to (0, 1) + eps = 1e-5 + norm = line.new_tensor([self.roi_size[0], self.roi_size[1]]) + eps + line = line / norm + line = line.flatten(-2, -1) + + return line diff --git a/projects/mmdet3d_plugin/models/map/map_blocks.py b/projects/mmdet3d_plugin/models/map/map_blocks.py new file mode 100644 index 0000000..1fb87bd --- /dev/null +++ b/projects/mmdet3d_plugin/models/map/map_blocks.py @@ -0,0 +1,199 @@ +from typing import Optional, List, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from mmcv.cnn import Linear, Scale, bias_init_with_prob +from mmcv.runner.base_module import Sequential, BaseModule +from mmcv.cnn import xavier_init +from mmcv.cnn.bricks.registry import ( + PLUGIN_LAYERS, + POSITIONAL_ENCODING, +) + +from ..blocks import linear_relu_ln + + +@POSITIONAL_ENCODING.register_module() +class SparsePoint3DEncoder(BaseModule): + def __init__( + self, + embed_dims: int = 256, + num_sample: int = 20, + coords_dim: int = 2, + ): + super(SparsePoint3DEncoder, self).__init__() + self.embed_dims = embed_dims + self.input_dims = num_sample * coords_dim + def embedding_layer(input_dims): + return nn.Sequential(*linear_relu_ln(embed_dims, 1, 2, input_dims)) + + self.pos_fc = embedding_layer(self.input_dims) + + def forward(self, anchor: torch.Tensor): + pos_feat = self.pos_fc(anchor) + return pos_feat + + +@PLUGIN_LAYERS.register_module() +class SparsePoint3DRefinementModule(BaseModule): + def __init__( + self, + embed_dims: int = 256, + num_sample: int = 20, + coords_dim: int = 2, + num_cls: int = 3, + with_cls_branch: bool = True, + ): + super(SparsePoint3DRefinementModule, self).__init__() + self.embed_dims = embed_dims + self.num_sample = 
num_sample
+        self.output_dim = num_sample * coords_dim  # flattened (x, y) coords per polyline
+        self.num_cls = num_cls
+
+        self.layers = nn.Sequential(
+            *linear_relu_ln(embed_dims, 2, 2),
+            Linear(self.embed_dims, self.output_dim),
+            Scale([1.0] * self.output_dim),
+        )
+
+        self.with_cls_branch = with_cls_branch
+        if with_cls_branch:
+            self.cls_layers = nn.Sequential(
+                *linear_relu_ln(embed_dims, 1, 2),
+                Linear(self.embed_dims, self.num_cls),
+            )
+
+    def init_weight(self):
+        if self.with_cls_branch:
+            bias_init = bias_init_with_prob(0.01)  # focal-loss style bias init (prior prob 0.01)
+            nn.init.constant_(self.cls_layers[-1].bias, bias_init)
+
+    def forward(
+        self,
+        instance_feature: torch.Tensor,
+        anchor: torch.Tensor,
+        anchor_embed: torch.Tensor,
+        time_interval: torch.Tensor = 1.0,  # NOTE(review): float default despite Tensor annotation; unused in this forward
+        return_cls=True,
+    ):
+        output = self.layers(instance_feature + anchor_embed)  # offsets from fused content + position features
+        output = output + anchor  # residual refinement of the input polyline
+        if return_cls:
+            assert self.with_cls_branch, "Without classification layers !!!"
+            cls = self.cls_layers(instance_feature)  # NOTE(review): cls sees instance_feature only — should anchor_embed be added? confirm
+        else:
+            cls = None
+        qt = None  # no quality-estimation branch for map instances
+        return output, cls, qt, None
+
+
+@PLUGIN_LAYERS.register_module()
+class SparsePoint3DKeyPointsGenerator(BaseModule):
+    def __init__(
+        self,
+        embed_dims: int = 256,
+        num_sample: int = 20,
+        num_learnable_pts: int = 0,
+        fix_height: Tuple = (0,),
+        ground_height: int = 0,
+    ):
+        super(SparsePoint3DKeyPointsGenerator, self).__init__()
+        self.embed_dims = embed_dims
+        self.num_sample = num_sample
+        self.num_learnable_pts = num_learnable_pts
+        self.num_pts = num_sample * len(fix_height) * num_learnable_pts  # total key points per anchor
+        if self.num_learnable_pts > 0:
+            self.learnable_fc = Linear(self.embed_dims, self.num_pts * 2)
+
+        self.fix_height = np.array(fix_height)
+        self.ground_height = ground_height
+
+    def init_weight(self):
+        if self.num_learnable_pts > 0:
+            xavier_init(self.learnable_fc, distribution="uniform", bias=0.0)
+
+    def forward(
+        self,
+        anchor,
+        instance_feature=None,
+        T_cur2temp_list=None,
+        cur_timestamp=None,
+        temp_timestamps=None,
+    ):
+        assert self.num_learnable_pts > 0, 'No 
learnable pts' + bs, num_anchor, _ = anchor.shape + key_points = anchor.view(bs, num_anchor, self.num_sample, -1) + offset = ( + self.learnable_fc(instance_feature) + .reshape(bs, num_anchor, self.num_sample, len(self.fix_height), self.num_learnable_pts, 2) + ) + key_points = offset + key_points[..., None, None, :] + key_points = torch.cat( + [ + key_points, + key_points.new_full(key_points.shape[:-1]+(1,), fill_value=self.ground_height), + ], + dim=-1, + ) + fix_height = key_points.new_tensor(self.fix_height) + height_offset = key_points.new_zeros([len(fix_height), 2]) + height_offset = torch.cat([height_offset, fix_height[:,None]], dim=-1) + key_points = key_points + height_offset[None, None, None, :, None] + key_points = key_points.flatten(2, 4) + if ( + cur_timestamp is None + or temp_timestamps is None + or T_cur2temp_list is None + or len(temp_timestamps) == 0 + ): + return key_points + + temp_key_points_list = [] + for i, t_time in enumerate(temp_timestamps): + temp_key_points = key_points + T_cur2temp = T_cur2temp_list[i].to(dtype=key_points.dtype) + temp_key_points = ( + T_cur2temp[:, None, None, :3] + @ torch.cat( + [ + temp_key_points, + torch.ones_like(temp_key_points[..., :1]), + ], + dim=-1, + ).unsqueeze(-1) + ) + temp_key_points = temp_key_points.squeeze(-1) + temp_key_points_list.append(temp_key_points) + return key_points, temp_key_points_list + + # @staticmethod + def anchor_projection( + self, + anchor, + T_src2dst_list, + src_timestamp=None, + dst_timestamps=None, + time_intervals=None, + ): + dst_anchors = [] + for i in range(len(T_src2dst_list)): + dst_anchor = anchor.clone() + bs, num_anchor, _ = anchor.shape + dst_anchor = dst_anchor.reshape(bs, num_anchor, self.num_sample, -1).flatten(1, 2) + T_src2dst = torch.unsqueeze( + T_src2dst_list[i].to(dtype=anchor.dtype), dim=1 + ) + + dst_anchor = ( + torch.matmul( + T_src2dst[..., :2, :2], dst_anchor[..., None] + ).squeeze(dim=-1) + + T_src2dst[..., :2, 3] + ) + + dst_anchor = 
dst_anchor.reshape(bs, num_anchor, self.num_sample, -1).flatten(2, 3) + dst_anchors.append(dst_anchor) + return dst_anchors \ No newline at end of file diff --git a/projects/mmdet3d_plugin/models/map/match_cost.py b/projects/mmdet3d_plugin/models/map/match_cost.py new file mode 100644 index 0000000..3c19a71 --- /dev/null +++ b/projects/mmdet3d_plugin/models/map/match_cost.py @@ -0,0 +1,104 @@ +import torch +from mmdet.core.bbox.match_costs.builder import MATCH_COST +from mmdet.core.bbox.match_costs import build_match_cost +from torch.nn.functional import smooth_l1_loss + + +@MATCH_COST.register_module() +class LinesL1Cost(object): + """LinesL1Cost. + Args: + weight (int | float, optional): loss_weight + """ + + def __init__(self, weight=1.0, beta=0.0, permute=False): + self.weight = weight + self.permute = permute + self.beta = beta + + def __call__(self, lines_pred, gt_lines, **kwargs): + """ + Args: + lines_pred (Tensor): predicted normalized lines: + [num_query, 2*num_points] + gt_lines (Tensor): Ground truth lines + [num_gt, 2*num_points] or [num_gt, num_permute, 2*num_points] + Returns: + torch.Tensor: reg_cost value with weight + shape [num_pred, num_gt] + """ + if self.permute: + assert len(gt_lines.shape) == 3 + else: + assert len(gt_lines.shape) == 2 + + num_pred, num_gt = len(lines_pred), len(gt_lines) + if self.permute: + # permute-invarint labels + gt_lines = gt_lines.flatten(0, 1) # (num_gt*num_permute, 2*num_pts) + + num_pts = lines_pred.shape[-1]//2 + + if self.beta > 0: + lines_pred = lines_pred.unsqueeze(1).repeat(1, len(gt_lines), 1) + gt_lines = gt_lines.unsqueeze(0).repeat(num_pred, 1, 1) + dist_mat = smooth_l1_loss(lines_pred, gt_lines, reduction='none', beta=self.beta).sum(-1) + + else: + dist_mat = torch.cdist(lines_pred, gt_lines, p=1) + + dist_mat = dist_mat / num_pts + + if self.permute: + # dist_mat: (num_pred, num_gt*num_permute) + dist_mat = dist_mat.view(num_pred, num_gt, -1) # (num_pred, num_gt, num_permute) + dist_mat, 
gt_permute_index = torch.min(dist_mat, 2)
+            return dist_mat * self.weight, gt_permute_index
+
+        return dist_mat * self.weight
+
+
+@MATCH_COST.register_module()
+class MapQueriesCost(object):
+
+    def __init__(self, cls_cost, reg_cost, iou_cost=None):
+
+        self.cls_cost = build_match_cost(cls_cost)
+        self.reg_cost = build_match_cost(reg_cost)
+
+        self.iou_cost = None
+        if iou_cost is not None:
+            self.iou_cost = build_match_cost(iou_cost)
+
+    def __call__(self, preds: dict, gts: dict, ignore_cls_cost: bool):
+
+        # classification and bboxcost.
+        cls_cost = self.cls_cost(preds['scores'], gts['labels'])
+
+        # regression cost
+        regkwargs = {}
+        if 'masks' in preds and 'masks' in gts:
+            assert isinstance(self.reg_cost, DynamicLinesCost), ' Issues!!'  # NOTE(review): DynamicLinesCost is neither imported nor defined — this branch raises NameError if masks are ever supplied
+            regkwargs = {
+                'masks_pred': preds['masks'],
+                'masks_gt': gts['masks'],
+            }
+
+        reg_cost = self.reg_cost(preds['lines'], gts['lines'], **regkwargs)
+        if self.reg_cost.permute:
+            reg_cost, gt_permute_idx = reg_cost  # permute-aware reg cost returns (cost, gt permutation index)
+
+        # weighted sum of above three costs
+        if ignore_cls_cost:
+            cost = reg_cost
+        else:
+            cost = cls_cost + reg_cost
+
+        # Iou
+        if self.iou_cost is not None:
+            iou_cost = self.iou_cost(preds['lines'],gts['lines'])
+            cost += iou_cost
+
+        if self.reg_cost.permute:
+            return cost, gt_permute_idx
+        return cost
diff --git a/projects/mmdet3d_plugin/models/map/target.py b/projects/mmdet3d_plugin/models/map/target.py
new file mode 100644
index 0000000..6695fce
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/map/target.py
@@ -0,0 +1,167 @@
+import torch
+import numpy as np
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+
+from mmdet.core.bbox.builder import (BBOX_SAMPLERS, BBOX_ASSIGNERS)
+from mmdet.core.bbox.match_costs import build_match_cost
+from mmdet.core import (build_assigner, build_sampler)
+from mmdet.core.bbox.assigners import (AssignResult, BaseAssigner)
+
+from ..base_target import BaseTargetWithDenoising
+
+
+@BBOX_SAMPLERS.register_module()
+class 
SparsePoint3DTarget(BaseTargetWithDenoising): + def __init__( + self, + assigner=None, + num_dn_groups=0, + dn_noise_scale=0.5, + max_dn_gt=32, + add_neg_dn=True, + num_temp_dn_groups=0, + num_cls=3, + num_sample=20, + roi_size=(30, 60), + ): + super(SparsePoint3DTarget, self).__init__( + num_dn_groups, num_temp_dn_groups + ) + self.assigner = build_assigner(assigner) + self.dn_noise_scale = dn_noise_scale + self.max_dn_gt = max_dn_gt + self.add_neg_dn = add_neg_dn + + self.num_cls = num_cls + self.num_sample = num_sample + self.roi_size = roi_size + + def sample( + self, + cls_preds, + pts_preds, + cls_targets, + pts_targets, + ): + pts_targets = [x.flatten(2, 3) if len(x.shape)==4 else x for x in pts_targets] + indices = [] + for(cls_pred, pts_pred, cls_target, pts_target) in zip( + cls_preds, pts_preds, cls_targets, pts_targets + ): + # normalize to (0, 1) + pts_pred = self.normalize_line(pts_pred) + pts_target = self.normalize_line(pts_target) + preds=dict(lines=pts_pred, scores=cls_pred) + gts=dict(lines=pts_target, labels=cls_target) + indice = self.assigner.assign(preds, gts) + indices.append(indice) + + bs, num_pred, num_cls = cls_preds.shape + output_cls_target = cls_targets[0].new_ones([bs, num_pred], dtype=torch.long) * num_cls + output_box_target = pts_preds.new_zeros(pts_preds.shape) + output_reg_weights = pts_preds.new_zeros(pts_preds.shape) + for i, (pred_idx, target_idx, gt_permute_index) in enumerate(indices): + if len(cls_targets[i]) == 0: + continue + permute_idx = gt_permute_index[pred_idx, target_idx] + output_cls_target[i, pred_idx] = cls_targets[i][target_idx] + output_box_target[i, pred_idx] = pts_targets[i][target_idx, permute_idx] + output_reg_weights[i, pred_idx] = 1 + + return output_cls_target, output_box_target, output_reg_weights + + def normalize_line(self, line): + if line.shape[0] == 0: + return line + + line = line.view(line.shape[:-1] + (self.num_sample, -1)) + + origin = -line.new_tensor([self.roi_size[0]/2, self.roi_size[1]/2]) 
+ line = line - origin + + # transform from range [0, 1] to (0, 1) + eps = 1e-5 + norm = line.new_tensor([self.roi_size[0], self.roi_size[1]]) + eps + line = line / norm + line = line.flatten(-2, -1) + + return line + + +@BBOX_ASSIGNERS.register_module() +class HungarianLinesAssigner(BaseAssigner): + """ + Computes one-to-one matching between predictions and ground truth. + This class computes an assignment between the targets and the predictions + based on the costs. The costs are weighted sum of three components: + classification cost and regression L1 cost. The + targets don't include the no_object, so generally there are more + predictions than targets. After the one-to-one matching, the un-matched + are treated as backgrounds. Thus each query prediction will be assigned + with `0` or a positive integer indicating the ground truth index: + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + Args: + cls_weight (int | float, optional): The scale factor for classification + cost. Default 1.0. + bbox_weight (int | float, optional): The scale factor for regression + L1 cost. Default 1.0. + """ + + def __init__(self, cost=dict, **kwargs): + self.cost = build_match_cost(cost) + + def assign(self, + preds: dict, + gts: dict, + ignore_cls_cost=False, + gt_bboxes_ignore=None, + eps=1e-7): + """ + Computes one-to-one matching based on the weighted costs. + This method assign each query prediction to a ground truth or + background. The `assigned_gt_inds` with -1 means don't care, + 0 means negative sample, and positive number is the index (1-based) + of assigned gt. + The assignment is done in the following steps, the order matters. + 1. assign every prediction to -1 + 2. compute the weighted costs + 3. do Hungarian matching on CPU based on the costs + 4. 
assign all to 0 (background) first, then for each matched pair + between predictions and gts, treat this prediction as foreground + and assign the corresponding gt index (plus 1) to it. + Args: + lines_pred (Tensor): predicted normalized lines: + [num_query, num_points, 2] + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class]. + + lines_gt (Tensor): Ground truth lines + [num_gt, num_points, 2]. + labels_gt (Tensor): Label of `gt_bboxes`, shape (num_gt,). + gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are + labelled as `ignored`. Default None. + eps (int | float, optional): A value added to the denominator for + numerical stability. Default 1e-7. + Returns: + :obj:`AssignResult`: The assigned result. + """ + assert gt_bboxes_ignore is None, \ + 'Only case when gt_bboxes_ignore is None is supported.' + + num_gts, num_lines = gts['lines'].size(0), preds['lines'].size(0) + if num_gts == 0 or num_lines == 0: + return None, None, None + + # compute the weighted costs + gt_permute_idx = None # (num_preds, num_gts) + if self.cost.reg_cost.permute: + cost, gt_permute_idx = self.cost(preds, gts, ignore_cls_cost) + else: + cost = self.cost(preds, gts, ignore_cls_cost) + + # do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu().numpy() + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) + return matched_row_inds, matched_col_inds, gt_permute_idx \ No newline at end of file diff --git a/projects/mmdet3d_plugin/models/motion/__init__.py b/projects/mmdet3d_plugin/models/motion/__init__.py new file mode 100644 index 0000000..11b2f42 --- /dev/null +++ b/projects/mmdet3d_plugin/models/motion/__init__.py @@ -0,0 +1,6 @@ +from .motion_planning_head import MotionPlanningHead +from .kinematic_motion_planning_head import KinematicMotionPlanningHead +from .motion_blocks import MotionPlanningRefinementModule +from .instance_queue import InstanceQueue +from .target import MotionTarget, 
PlanningTarget +from .decoder import SparseBox3DMotionDecoder, HierarchicalPlanningDecoder diff --git a/projects/mmdet3d_plugin/models/motion/decoder.py b/projects/mmdet3d_plugin/models/motion/decoder.py new file mode 100644 index 0000000..fa673cf --- /dev/null +++ b/projects/mmdet3d_plugin/models/motion/decoder.py @@ -0,0 +1,329 @@ +from typing import Optional + +import numpy as np +import torch + +from mmdet.core.bbox.builder import BBOX_CODERS + +from projects.mmdet3d_plugin.core.box3d import * +from projects.mmdet3d_plugin.models.detection3d.decoder import * +from projects.mmdet3d_plugin.datasets.utils import box3d_to_corners + + +@BBOX_CODERS.register_module() +class SparseBox3DMotionDecoder(SparseBox3DDecoder): + def __init__(self): + super(SparseBox3DMotionDecoder, self).__init__() + + def decode( + self, + cls_scores, + box_preds, + instance_id=None, + quality=None, + motion_output=None, + output_idx=-1, + ): + squeeze_cls = instance_id is not None + + cls_scores = cls_scores[output_idx].sigmoid() + + if squeeze_cls: + cls_scores, cls_ids = cls_scores.max(dim=-1) + cls_scores = cls_scores.unsqueeze(dim=-1) + + box_preds = box_preds[output_idx] + bs, num_pred, num_cls = cls_scores.shape + cls_scores, indices = cls_scores.flatten(start_dim=1).topk( + self.num_output, dim=1, sorted=self.sorted + ) + if not squeeze_cls: + cls_ids = indices % num_cls + if self.score_threshold is not None: + mask = cls_scores >= self.score_threshold + + if quality[output_idx] is None: + quality = None + if quality is not None: + centerness = quality[output_idx][..., CNS] + centerness = torch.gather(centerness, 1, indices // num_cls) + cls_scores_origin = cls_scores.clone() + cls_scores *= centerness.sigmoid() + cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True) + if not squeeze_cls: + cls_ids = torch.gather(cls_ids, 1, idx) + if self.score_threshold is not None: + mask = torch.gather(mask, 1, idx) + indices = torch.gather(indices, 1, idx) + + output = [] + 
anchor_queue = motion_output["anchor_queue"] + anchor_queue = torch.stack(anchor_queue, dim=2) + period = motion_output["period"] + + for i in range(bs): + category_ids = cls_ids[i] + if squeeze_cls: + category_ids = category_ids[indices[i]] + scores = cls_scores[i] + box = box_preds[i, indices[i] // num_cls] + if self.score_threshold is not None: + category_ids = category_ids[mask[i]] + scores = scores[mask[i]] + box = box[mask[i]] + if quality is not None: + scores_origin = cls_scores_origin[i] + if self.score_threshold is not None: + scores_origin = scores_origin[mask[i]] + + box = decode_box(box) + trajs = motion_output["prediction"][-1] + traj_cls = motion_output["classification"][-1].sigmoid() + traj = trajs[i, indices[i] // num_cls] + traj_cls = traj_cls[i, indices[i] // num_cls] + if self.score_threshold is not None: + traj = traj[mask[i]] + traj_cls = traj_cls[mask[i]] + traj = traj.cumsum(dim=-2) + box[:, None, None, :2] + output.append( + { + "trajs_3d": traj.cpu(), + "trajs_score": traj_cls.cpu() + } + ) + + temp_anchor = anchor_queue[i, indices[i] // num_cls] + temp_period = period[i, indices[i] // num_cls] + if self.score_threshold is not None: + temp_anchor = temp_anchor[mask[i]] + temp_period = temp_period[mask[i]] + num_pred, queue_len = temp_anchor.shape[:2] + temp_anchor = temp_anchor.flatten(0, 1) + temp_anchor = decode_box(temp_anchor) + temp_anchor = temp_anchor.reshape([num_pred, queue_len, box.shape[-1]]) + output[-1]['anchor_queue'] = temp_anchor.cpu() + output[-1]['period'] = temp_period.cpu() + + return output + + +@BBOX_CODERS.register_module() +class HierarchicalPlanningDecoder(object): + def __init__( + self, + ego_fut_ts, + ego_fut_mode, + use_rescore=False, + ): + super(HierarchicalPlanningDecoder, self).__init__() + self.ego_fut_ts = ego_fut_ts + self.ego_fut_mode = ego_fut_mode + self.use_rescore = use_rescore + + def decode( + self, + det_output, + motion_output, + planning_output, + data, + ): + classification = 
planning_output['classification'][-1] + prediction = planning_output['prediction'][-1] + bs = classification.shape[0] + classification = classification.reshape(bs, 3, self.ego_fut_mode) + prediction = prediction.reshape(bs, 3, self.ego_fut_mode, self.ego_fut_ts, 2).cumsum(dim=-2) + classification, final_planning = self.select(det_output, motion_output, classification, prediction, data) + anchor_queue = planning_output["anchor_queue"] + anchor_queue = torch.stack(anchor_queue, dim=2) + period = planning_output["period"] + output = [] + for i, (cls, pred) in enumerate(zip(classification, prediction)): + output.append( + { + "planning_score": cls.sigmoid().cpu(), + "planning": pred.cpu(), + "final_planning": final_planning[i].cpu(), + "ego_period": period[i].cpu(), + "ego_anchor_queue": decode_box(anchor_queue[i]).cpu(), + } + ) + + return output + + def select( + self, + det_output, + motion_output, + plan_cls, + plan_reg, + data, + ): + det_classification = det_output["classification"][-1].sigmoid() + det_anchors = det_output["prediction"][-1] + det_confidence = det_classification.max(dim=-1).values + motion_cls = motion_output["classification"][-1].sigmoid() + motion_reg = motion_output["prediction"][-1] + + # cmd select + bs = motion_cls.shape[0] + bs_indices = torch.arange(bs, device=motion_cls.device) + cmd = data['gt_ego_fut_cmd'].argmax(dim=-1) + plan_cls_full = plan_cls.detach().clone() + plan_cls = plan_cls[bs_indices, cmd] + plan_reg = plan_reg[bs_indices, cmd] + + # rescore + if self.use_rescore: + plan_cls = self.rescore( + plan_cls, + plan_reg, + motion_cls, + motion_reg, + det_anchors, + det_confidence, + ) + plan_cls_full[bs_indices, cmd] = plan_cls + mode_idx = plan_cls.argmax(dim=-1) + final_planning = plan_reg[bs_indices, mode_idx] + return plan_cls_full, final_planning + + def rescore( + self, + plan_cls, + plan_reg, + motion_cls, + motion_reg, + det_anchors, + det_confidence, + score_thresh=0.5, + static_dis_thresh=0.5, + dim_scale=1.1, + 
num_motion_mode=1, + offset=0.5, + ): + + def cat_with_zero(traj): + zeros = traj.new_zeros(traj.shape[:-2] + (1, 2)) + traj_cat = torch.cat([zeros, traj], dim=-2) + return traj_cat + + def get_yaw(traj, start_yaw=np.pi/2): + yaw = traj.new_zeros(traj.shape[:-1]) + yaw[..., 1:-1] = torch.atan2( + traj[..., 2:, 1] - traj[..., :-2, 1], + traj[..., 2:, 0] - traj[..., :-2, 0], + ) + yaw[..., -1] = torch.atan2( + traj[..., -1, 1] - traj[..., -2, 1], + traj[..., -1, 0] - traj[..., -2, 0], + ) + yaw[..., 0] = start_yaw + # for static object, estimated future yaw would be unstable + start = traj[..., 0, :] + end = traj[..., -1, :] + dist = torch.linalg.norm(end - start, dim=-1) + mask = dist < static_dis_thresh + start_yaw = yaw[..., 0].unsqueeze(-1) + yaw = torch.where( + mask.unsqueeze(-1), + start_yaw, + yaw, + ) + return yaw.unsqueeze(-1) + + ## ego + bs = plan_reg.shape[0] + plan_reg_cat = cat_with_zero(plan_reg) + ego_box = det_anchors.new_zeros(bs, self.ego_fut_mode, self.ego_fut_ts + 1, 7) + ego_box[..., [X, Y]] = plan_reg_cat + ego_box[..., [W, L, H]] = ego_box.new_tensor([4.08, 1.73, 1.56]) * dim_scale + ego_box[..., [YAW]] = get_yaw(plan_reg_cat) + + ## motion + motion_reg = motion_reg[..., :self.ego_fut_ts, :].cumsum(-2) + motion_reg = cat_with_zero(motion_reg) + det_anchors[:, :, None, None, :2] + _, motion_mode_idx = torch.topk(motion_cls, num_motion_mode, dim=-1) + motion_mode_idx = motion_mode_idx[..., None, None].repeat(1, 1, 1, self.ego_fut_ts + 1, 2) + motion_reg = torch.gather(motion_reg, 2, motion_mode_idx) + + motion_box = motion_reg.new_zeros(motion_reg.shape[:-1] + (7,)) + motion_box[..., [X, Y]] = motion_reg + motion_box[..., [W, L, H]] = det_anchors[..., None, None, [W, L, H]].exp() + box_yaw = torch.atan2( + det_anchors[..., SIN_YAW], + det_anchors[..., COS_YAW], + ) + motion_box[..., [YAW]] = get_yaw(motion_reg, box_yaw.unsqueeze(-1)) + + filter_mask = det_confidence < score_thresh + motion_box[filter_mask] = 1e6 + + ego_box = ego_box[..., 1:, 
:] + motion_box = motion_box[..., 1:, :] + + bs, num_ego_mode, ts, _ = ego_box.shape + bs, num_anchor, num_motion_mode, ts, _ = motion_box.shape + ego_box = ego_box[:, None, None].repeat(1, num_anchor, num_motion_mode, 1, 1, 1).flatten(0, -2) + motion_box = motion_box.unsqueeze(3).repeat(1, 1, 1, num_ego_mode, 1, 1).flatten(0, -2) + + ego_box[0] += offset * torch.cos(ego_box[6]) + ego_box[1] += offset * torch.sin(ego_box[6]) + col = check_collision(ego_box, motion_box) + col = col.reshape(bs, num_anchor, num_motion_mode, num_ego_mode, ts).permute(0, 3, 1, 2, 4) + col = col.flatten(2, -1).any(dim=-1) + all_col = col.all(dim=-1) + col[all_col] = False # for case that all modes collide, no need to rescore + score_offset = col.float() * -999 + plan_cls = plan_cls + score_offset + return plan_cls + + +def check_collision(boxes1, boxes2): + ''' + A rough check for collision detection: + check if any corner point of boxes1 is inside boxes2 and vice versa. + + boxes1: tensor with shape [N, 7], [x, y, z, w, l, h, yaw] + boxes2: tensor with shape [N, 7] + ''' + col_1 = corners_in_box(boxes1.clone(), boxes2.clone()) + col_2 = corners_in_box(boxes2.clone(), boxes1.clone()) + collision = torch.logical_or(col_1, col_2) + + return collision + +def corners_in_box(boxes1, boxes2): + if boxes1.shape[0] == 0 or boxes2.shape[0] == 0: + return False + + boxes1_yaw = boxes1[:, 6].clone() + boxes1_loc = boxes1[:, :3].clone() + cos_yaw = torch.cos(-boxes1_yaw) + sin_yaw = torch.sin(-boxes1_yaw) + rot_mat_T = torch.stack( + [ + torch.stack([cos_yaw, sin_yaw]), + torch.stack([-sin_yaw, cos_yaw]), + ] + ) + # translate and rotate boxes + boxes1[:, :3] = boxes1[:, :3] - boxes1_loc + boxes1[:, :2] = torch.einsum('ij,jki->ik', boxes1[:, :2], rot_mat_T) + boxes1[:, 6] = boxes1[:, 6] - boxes1_yaw + + boxes2[:, :3] = boxes2[:, :3] - boxes1_loc + boxes2[:, :2] = torch.einsum('ij,jki->ik', boxes2[:, :2], rot_mat_T) + boxes2[:, 6] = boxes2[:, 6] - boxes1_yaw + + corners_box2 = 
box3d_to_corners(boxes2)[:, [0, 3, 7, 4], :2] + corners_box2 = torch.from_numpy(corners_box2).to(boxes2.device) + H = boxes1[:, [3]] + W = boxes1[:, [4]] + + collision = torch.logical_and( + torch.logical_and(corners_box2[..., 0] <= H / 2, corners_box2[..., 0] >= -H / 2), + torch.logical_and(corners_box2[..., 1] <= W / 2, corners_box2[..., 1] >= -W / 2), + ) + collision = collision.any(dim=-1) + + return collision \ No newline at end of file diff --git a/projects/mmdet3d_plugin/models/motion/instance_queue.py b/projects/mmdet3d_plugin/models/motion/instance_queue.py new file mode 100644 index 0000000..1905d77 --- /dev/null +++ b/projects/mmdet3d_plugin/models/motion/instance_queue.py @@ -0,0 +1,213 @@ +import copy +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np + +from mmcv.utils import build_from_cfg +from mmcv.cnn.bricks.registry import PLUGIN_LAYERS + +from projects.mmdet3d_plugin.ops import feature_maps_format +from projects.mmdet3d_plugin.core.box3d import * + + +@PLUGIN_LAYERS.register_module() +class InstanceQueue(nn.Module): + def __init__( + self, + embed_dims, + queue_length=0, + tracking_threshold=0, + feature_map_scale=None, + ): + super(InstanceQueue, self).__init__() + self.embed_dims = embed_dims + self.queue_length = queue_length + self.tracking_threshold = tracking_threshold + + kernel_size = tuple([int(x / 2) for x in feature_map_scale]) + self.ego_feature_encoder = nn.Sequential( + nn.Conv2d(embed_dims, embed_dims, 3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(embed_dims), + nn.Conv2d(embed_dims, embed_dims, 3, stride=2, padding=1, bias=False), + nn.BatchNorm2d(embed_dims), + nn.ReLU(), + nn.AvgPool2d(kernel_size), + ) + self.ego_anchor = nn.Parameter( + torch.tensor([[0, 0.5, -1.84 + 1.56/2, np.log(4.08), np.log(1.73), np.log(1.56), 1, 0, 0, 0, 0],], dtype=torch.float32), + requires_grad=False, + ) + + self.reset() + + def reset(self): + self.metas = None + self.prev_instance_id = None + 
self.prev_confidence = None + self.period = None + self.instance_feature_queue = [] + self.anchor_queue = [] + self.prev_ego_status = None + self.ego_period = None + self.ego_feature_queue = [] + self.ego_anchor_queue = [] + + def get( + self, + det_output, + feature_maps, + metas, + batch_size, + mask, + anchor_handler, + ): + if ( + self.period is not None + and batch_size == self.period.shape[0] + ): + if anchor_handler is not None: + T_temp2cur = feature_maps[0].new_tensor( + np.stack( + [ + x["T_global_inv"] + @ self.metas["img_metas"][i]["T_global"] + for i, x in enumerate(metas["img_metas"]) + ] + ) + ) + for i in range(len(self.anchor_queue)): + temp_anchor = self.anchor_queue[i] + temp_anchor = anchor_handler.anchor_projection( + temp_anchor, + [T_temp2cur], + )[0] + self.anchor_queue[i] = temp_anchor + for i in range(len(self.ego_anchor_queue)): + temp_anchor = self.ego_anchor_queue[i] + temp_anchor = anchor_handler.anchor_projection( + temp_anchor, + [T_temp2cur], + )[0] + self.ego_anchor_queue[i] = temp_anchor + else: + self.reset() + + self.prepare_motion(det_output, mask) + ego_feature, ego_anchor = self.prepare_planning(feature_maps, mask, batch_size) + + # temporal + temp_instance_feature = torch.stack(self.instance_feature_queue, dim=2) + temp_anchor = torch.stack(self.anchor_queue, dim=2) + temp_ego_feature = torch.stack(self.ego_feature_queue, dim=2) + temp_ego_anchor = torch.stack(self.ego_anchor_queue, dim=2) + + period = torch.cat([self.period, self.ego_period], dim=1) + temp_instance_feature = torch.cat([temp_instance_feature, temp_ego_feature], dim=1) + temp_anchor = torch.cat([temp_anchor, temp_ego_anchor], dim=1) + num_agent = temp_anchor.shape[1] + + temp_mask = torch.arange(len(self.anchor_queue), 0, -1, device=temp_anchor.device) + temp_mask = temp_mask[None, None].repeat((batch_size, num_agent, 1)) + temp_mask = torch.gt(temp_mask, period[..., None]) + + return ego_feature, ego_anchor, temp_instance_feature, temp_anchor, temp_mask + + 
def prepare_motion( + self, + det_output, + mask, + ): + instance_feature = det_output["instance_feature"] + det_anchors = det_output["prediction"][-1] + + if self.period == None: + self.period = instance_feature.new_zeros(instance_feature.shape[:2]).long() + else: + instance_id = det_output['instance_id'] + prev_instance_id = self.prev_instance_id + match = instance_id[..., None] == prev_instance_id[:, None] + if self.tracking_threshold > 0: + temp_mask = self.prev_confidence > self.tracking_threshold + match = match * temp_mask.unsqueeze(1) + + for i in range(len(self.instance_feature_queue)): + temp_feature = self.instance_feature_queue[i] + temp_feature = ( + match[..., None] * temp_feature[:, None] + ).sum(dim=2) + self.instance_feature_queue[i] = temp_feature + + temp_anchor = self.anchor_queue[i] + temp_anchor = ( + match[..., None] * temp_anchor[:, None] + ).sum(dim=2) + self.anchor_queue[i] = temp_anchor + + self.period = ( + match * self.period[:, None] + ).sum(dim=2) + + self.instance_feature_queue.append(instance_feature.detach()) + self.anchor_queue.append(det_anchors.detach()) + self.period += 1 + + if len(self.instance_feature_queue) > self.queue_length: + self.instance_feature_queue.pop(0) + self.anchor_queue.pop(0) + self.period = torch.clip(self.period, 0, self.queue_length) + + def prepare_planning( + self, + feature_maps, + mask, + batch_size, + ): + ## ego instance init + feature_maps_inv = feature_maps_format(feature_maps, inverse=True) + feature_map = feature_maps_inv[0][-1][:, 0] + ego_feature = self.ego_feature_encoder(feature_map) + ego_feature = ego_feature.unsqueeze(1).squeeze(-1).squeeze(-1) + + ego_anchor = torch.tile( + self.ego_anchor[None], (batch_size, 1, 1) + ) + if self.prev_ego_status is not None: + prev_ego_status = torch.where( + mask[:, None, None], + self.prev_ego_status, + self.prev_ego_status.new_tensor(0), + ) + ego_anchor[..., VY] = prev_ego_status[..., 6] + + if self.ego_period == None: + self.ego_period = 
ego_feature.new_zeros((batch_size, 1)).long() + else: + self.ego_period = torch.where( + mask[:, None], + self.ego_period, + self.ego_period.new_tensor(0), + ) + + self.ego_feature_queue.append(ego_feature.detach()) + self.ego_anchor_queue.append(ego_anchor.detach()) + self.ego_period += 1 + + if len(self.ego_feature_queue) > self.queue_length: + self.ego_feature_queue.pop(0) + self.ego_anchor_queue.pop(0) + self.ego_period = torch.clip(self.ego_period, 0, self.queue_length) + + return ego_feature, ego_anchor + + def cache_motion(self, instance_feature, det_output, metas): + det_classification = det_output["classification"][-1].sigmoid() + det_confidence = det_classification.max(dim=-1).values + instance_id = det_output['instance_id'] + self.metas = metas + self.prev_confidence = det_confidence.detach() + self.prev_instance_id = instance_id + + def cache_planning(self, ego_feature, ego_status): + self.prev_ego_status = ego_status.detach() + self.ego_feature_queue[-1] = ego_feature.detach() diff --git a/projects/mmdet3d_plugin/models/motion/kinematic_motion_planning_head.py b/projects/mmdet3d_plugin/models/motion/kinematic_motion_planning_head.py new file mode 100644 index 0000000..88f0bf2 --- /dev/null +++ b/projects/mmdet3d_plugin/models/motion/kinematic_motion_planning_head.py @@ -0,0 +1,409 @@ +import torch + +from mmcv.runner import BaseModule, force_fp32 +from mmcv.utils import build_from_cfg +from mmcv.cnn.bricks.registry import PLUGIN_LAYERS +from mmdet.core import reduce_mean +from mmdet.core.bbox.builder import BBOX_SAMPLERS, BBOX_CODERS +from mmdet.models import HEADS, build_loss + +from projects.mmdet3d_plugin.core.box3d import VX, VY + + +@HEADS.register_module() +class KinematicMotionPlanningHead(BaseModule): + """Heuristic kinematic baseline for motion and ego-planning prediction. + + Implements the CTRA (Constant Turn Rate and Acceleration) kinematic model. 
+ Special cases are selected via config flags: + + use_acceleration use_turn_rate Model + ────────────────────────────────────── + False False CV (Constant Velocity) + True False CA (Constant Acceleration) + False True CVTR (Constant Velocity + Turn Rate) + True True CTRA (Constant Turn Rate + Acceleration) + + No prediction parameters are trained. Trajectories are derived entirely + from the GT box state stored in ``det_output["prediction"][-1]`` and, when + acceleration/turn-rate estimation is enabled, the previous-frame anchor + history in ``instance_queue.anchor_queue``. + + CTRA integration + ---------------- + Given scalar speed ``s₀``, heading ``θ₀``, turn rate ``ω`` and longitudinal + acceleration ``a`` (all estimated at the current frame), each future + timestep delta is computed via midpoint integration:: + + s_mid(t) = s₀ + a · (t + ½) · dt (clipped to ≥ 0) + θ_mid(t) = θ₀ + ω · (t + ½) · dt + Δx_t = s_mid(t) · cos(θ_mid(t)) · dt + Δy_t = s_mid(t) · sin(θ_mid(t)) · dt + + where ``t`` is zero-indexed over ``fut_ts`` steps. + + When ``anchor_queue`` is ``None`` (first frame of a sequence), acceleration + and turn-rate estimates are unavailable and the model falls back to CV. + + Interface + --------- + Signature-compatible with ``MotionPlanningHead``; swap via config + ``motion_plan_head.type``. + + Args: + fut_ts: Agent future timesteps. + fut_mode: Number of trajectory hypothesis modes. + ego_fut_ts: Ego future timesteps. + ego_fut_mode: Number of ego hypothesis modes. + dt: Seconds per timestep (0.5 s for nuScenes). + use_acceleration: Estimate longitudinal acceleration from the + velocity delta between current and previous frame. + use_turn_rate: Estimate yaw rate from the heading delta between + current and previous frame. + instance_queue: Config for InstanceQueue (anchor history tracking). + motion_sampler / motion_loss_{cls,reg}: Motion loss modules. + planning_sampler / plan_loss_{cls,reg,status}: Planning loss modules. 
+ motion_decoder / planning_decoder: Decoder configs. + num_det / num_map: API compatibility (unused). + **kwargs: Absorbs unused ``MotionPlanningHead`` params so that + the same config dict works with just a type change. + """ + + def __init__( + self, + fut_ts=12, + fut_mode=6, + ego_fut_ts=6, + ego_fut_mode=6, + dt=0.5, + use_acceleration=False, + use_turn_rate=False, + instance_queue=None, + motion_sampler=None, + motion_loss_cls=None, + motion_loss_reg=None, + planning_sampler=None, + plan_loss_cls=None, + plan_loss_reg=None, + plan_loss_status=None, + motion_decoder=None, + planning_decoder=None, + num_det=50, + num_map=10, + init_cfg=None, + ): + super().__init__(init_cfg) + self.fut_ts = fut_ts + self.fut_mode = fut_mode + self.ego_fut_ts = ego_fut_ts + self.ego_fut_mode = ego_fut_mode + self.dt = dt + self.use_acceleration = use_acceleration + self.use_turn_rate = use_turn_rate + self.num_det = num_det + self.num_map = num_map + + def _build(cfg, registry): + return build_from_cfg(cfg, registry) if cfg is not None else None + + self.instance_queue = _build(instance_queue, PLUGIN_LAYERS) + if self.instance_queue is not None: + # ego_feature_encoder is unused in kinematic mode; freeze to avoid + # DDP "unused parameter" errors. 
+ for p in self.instance_queue.ego_feature_encoder.parameters(): + p.requires_grad_(False) + + self.motion_sampler = _build(motion_sampler, BBOX_SAMPLERS) + self.planning_sampler = _build(planning_sampler, BBOX_SAMPLERS) + self.motion_decoder = _build(motion_decoder, BBOX_CODERS) + self.planning_decoder = _build(planning_decoder, BBOX_CODERS) + + self.motion_loss_cls = build_loss(motion_loss_cls) + self.motion_loss_reg = build_loss(motion_loss_reg) + self.plan_loss_cls = build_loss(plan_loss_cls) + self.plan_loss_reg = build_loss(plan_loss_reg) + self.plan_loss_status = build_loss(plan_loss_status) + + def init_weights(self): + if self.instance_queue is not None: + for m in self.instance_queue.modules(): + if hasattr(m, "init_weight"): + m.init_weight() + + # ------------------------------------------------------------------ # + # Kinematic helpers + # ------------------------------------------------------------------ # + + def _ctra_integrate(self, speed, heading, accel, omega, fut_ts, dt, device): + """Midpoint-rule CTRA integration. + + Args: + speed: (bs, N) or (bs,) — current scalar speed. + heading: (bs, N) or (bs,) — current heading in radians. + accel: same shape — longitudinal acceleration (m/s²). + omega: same shape — yaw rate (rad/s). + fut_ts: number of future timesteps. + dt: seconds per timestep. + + Returns: + pred: (*shape, fut_ts, 2) per-step XY deltas. + """ + shape = speed.shape + t = torch.arange(fut_ts, device=device, dtype=speed.dtype) # (fut_ts,) + t_mid = t + 0.5 # midpoint + + # Broadcast: (*shape, 1) × (fut_ts,) → (*shape, fut_ts) + s_mid = (speed.unsqueeze(-1) + accel.unsqueeze(-1) * t_mid * dt).clamp(min=0) + h_mid = heading.unsqueeze(-1) + omega.unsqueeze(-1) * t_mid * dt + + dx = s_mid * torch.cos(h_mid) * dt # (*shape, fut_ts) + dy = s_mid * torch.sin(h_mid) * dt + return torch.stack([dx, dy], dim=-1) # (*shape, fut_ts, 2) + + def _agent_kinematics(self, gt_anchors): + """Return (speed, heading, accel, omega) for each agent anchor. 
+ + Falls back to zero accel/omega if anchor_queue is unavailable (first + frame) or if the corresponding flag is disabled. + + Args: + gt_anchors: (bs, N, 11) encoded GT anchors. + + Returns: + speed, heading, accel, omega — each (bs, N). + """ + vx = gt_anchors[..., VX] + vy = gt_anchors[..., VY] + speed = torch.sqrt(vx ** 2 + vy ** 2).clamp(min=1e-6) + heading = torch.atan2(vy, vx) + + # anchor_queue[-1] is the current frame (just appended by prepare_motion), + # so the actual previous frame is at index -2. + queue = self.instance_queue.anchor_queue # list of (bs, N, 11) tensors + have_history = len(queue) >= 2 + + if self.use_acceleration and have_history: + prev = queue[-2] # (bs, N, 11) — previous frame + prev_speed = torch.sqrt( + prev[..., VX] ** 2 + prev[..., VY] ** 2 + ).clamp(min=1e-6) + accel = (speed - prev_speed) / self.dt + else: + accel = torch.zeros_like(speed) + + if self.use_turn_rate and have_history: + prev = queue[-2] # (bs, N, 11) — previous frame + prev_heading = torch.atan2(prev[..., VY], prev[..., VX]) + omega = (heading - prev_heading) / self.dt + else: + omega = torch.zeros_like(heading) + + return speed, heading, accel, omega + + def _ego_kinematics(self, ego_status, device): + """Return (speed, heading, accel, omega) for the ego vehicle. + + Args: + ego_status: (bs, 9) — index 6/7 = vx/vy in current frame. + + Returns: + speed, heading, accel, omega — each (bs,). + """ + vx = ego_status[:, 6] + vy = ego_status[:, 7] + speed = torch.sqrt(vx ** 2 + vy ** 2).clamp(min=1e-6) + heading = torch.atan2(vy, vx) + + # ego_anchor_queue[-1] stores the previous frame's velocity in VX/VY slots, + # but on the very first frame prev_ego_status was None so VX=VY=0. + # Require len >= 2 so the first frame always falls back to zero accel/omega. 
+ ego_queue = self.instance_queue.ego_anchor_queue # list of (bs, 1, 11) tensors + have_history = len(ego_queue) >= 2 + + if self.use_acceleration and have_history: + prev_ego = ego_queue[-1][:, 0] # (bs, 11) + prev_vx = prev_ego[..., VX] # (bs,) + prev_vy = prev_ego[..., VY] # (bs,) + prev_speed = torch.sqrt(prev_vx ** 2 + prev_vy ** 2).clamp(min=1e-6) + accel = (speed - prev_speed) / self.dt + else: + accel = torch.zeros_like(speed) + + if self.use_turn_rate and have_history: + prev_ego = ego_queue[-1][:, 0] # (bs, 11) + prev_heading = torch.atan2(prev_ego[..., VY], prev_ego[..., VX]) + omega = (heading - prev_heading) / self.dt + else: + omega = torch.zeros_like(heading) + + return speed, heading, accel, omega + + # ------------------------------------------------------------------ # + # Forward + # ------------------------------------------------------------------ # + + def forward( + self, + det_output, + map_output, + feature_maps, + metas, + anchor_encoder, + mask, + anchor_handler, + ): + bs = len(metas["img_metas"]) + gt_anchors = det_output["prediction"][-1] # (bs, num_anchor, 11) + num_anchor = gt_anchors.shape[1] + device = gt_anchors.device + + # Update instance_queue for anchor history tracking. 
+ ego_feature, ego_anchor, _, _, _ = self.instance_queue.get( + det_output, feature_maps, metas, bs, mask, anchor_handler + ) + + # -------- agent motion -------------------------------------------- # + speed, heading, accel, omega = self._agent_kinematics(gt_anchors) + # pred: (bs, N, fut_ts, 2) → expand modes → (bs, N, fut_mode, fut_ts, 2) + pred = self._ctra_integrate(speed, heading, accel, omega, + self.fut_ts, self.dt, device) + motion_pred = ( + pred[:, :, None, :, :] + .expand(bs, num_anchor, self.fut_mode, self.fut_ts, 2) + .contiguous() + ) + motion_cls = gt_anchors.new_zeros(bs, num_anchor, self.fut_mode) + + # -------- ego planning -------------------------------------------- # + ego_status = metas["ego_status"] # (bs, 9) + e_speed, e_heading, e_accel, e_omega = self._ego_kinematics( + ego_status, device + ) + # pred: (bs, ego_fut_ts, 2) → expand → (bs, 3*ego_fut_mode, ego_fut_ts, 2) + ego_pred = self._ctra_integrate(e_speed, e_heading, e_accel, e_omega, + self.ego_fut_ts, self.dt, device) + plan_pred = ( + ego_pred[:, None, :, :] + .expand(bs, 3 * self.ego_fut_mode, self.ego_fut_ts, 2) + .contiguous() + ) + plan_cls = gt_anchors.new_zeros(bs, 3 * self.ego_fut_mode) + plan_status = ego_status.unsqueeze(1) # (bs, 1, 9) + + # -------- update queue state -------------------------------------- # + zero_feats = gt_anchors.new_zeros( + bs, num_anchor, det_output["instance_feature"].shape[-1] + ) + self.instance_queue.cache_motion(zero_feats, det_output, metas) + self.instance_queue.cache_planning(ego_feature, plan_status) + + motion_output = { + "classification": [motion_cls], + "prediction": [motion_pred], + "period": self.instance_queue.period, + "anchor_queue": self.instance_queue.anchor_queue, + } + planning_output = { + "classification": [plan_cls], + "prediction": [plan_pred], + "status": [plan_status], + "period": self.instance_queue.ego_period, + "anchor_queue": self.instance_queue.ego_anchor_queue, + } + return motion_output, planning_output + + # 
------------------------------------------------------------------ # + # Loss (computed for monitoring; no params are optimised) + # ------------------------------------------------------------------ # + + @force_fp32(apply_to=("model_outs",)) + def loss(self, motion_model_outs, planning_model_outs, data, motion_loss_cache): + loss = {} + loss.update(self.loss_motion(motion_model_outs, data, motion_loss_cache)) + loss.update(self.loss_planning(planning_model_outs, data)) + return loss + + @force_fp32(apply_to=("model_outs",)) + def loss_motion(self, model_outs, data, motion_loss_cache): + output = {} + for i, (cls, reg) in enumerate(zip( + model_outs["classification"], model_outs["prediction"] + )): + cls_target, cls_weight, reg_pred, reg_target, reg_weight, num_pos = ( + self.motion_sampler.sample( + reg, + data["gt_agent_fut_trajs"], + data["gt_agent_fut_masks"], + motion_loss_cache, + ) + ) + num_pos = max(reduce_mean(num_pos), 1.0) + + cls_loss = self.motion_loss_cls( + cls.flatten(end_dim=1), + cls_target.flatten(end_dim=1), + weight=cls_weight.flatten(end_dim=1), + avg_factor=num_pos, + ) + reg_pred = reg_pred.flatten(end_dim=1).cumsum(dim=-2) + reg_target = reg_target.flatten(end_dim=1).cumsum(dim=-2) + reg_loss = self.motion_loss_reg( + reg_pred, reg_target, + weight=reg_weight.flatten(end_dim=1).unsqueeze(-1), + avg_factor=num_pos, + ) + output[f"motion_loss_cls_{i}"] = cls_loss + output[f"motion_loss_reg_{i}"] = reg_loss + return output + + @force_fp32(apply_to=("model_outs",)) + def loss_planning(self, model_outs, data): + output = {} + for i, (cls, reg, status) in enumerate(zip( + model_outs["classification"], + model_outs["prediction"], + model_outs["status"], + )): + cls, cls_target, cls_weight, reg_pred, reg_target, reg_weight = ( + self.planning_sampler.sample( + cls, reg, + data["gt_ego_fut_trajs"], + data["gt_ego_fut_masks"], + data, + ) + ) + cls_loss = self.plan_loss_cls( + cls.flatten(end_dim=1), + cls_target.flatten(end_dim=1), + 
weight=cls_weight.flatten(end_dim=1), + ) + reg_loss = self.plan_loss_reg( + reg_pred.flatten(end_dim=1), + reg_target.flatten(end_dim=1), + weight=reg_weight.flatten(end_dim=1).unsqueeze(-1), + ) + status_loss = self.plan_loss_status( + status.squeeze(1), data["ego_status"] + ) + output[f"planning_loss_cls_{i}"] = cls_loss + output[f"planning_loss_reg_{i}"] = reg_loss + output[f"planning_loss_status_{i}"] = status_loss + return output + + # ------------------------------------------------------------------ # + # Post-process + # ------------------------------------------------------------------ # + + @force_fp32(apply_to=("model_outs",)) + def post_process(self, det_output, motion_output, planning_output, data): + motion_result = self.motion_decoder.decode( + det_output["classification"], + det_output["prediction"], + det_output.get("instance_id"), + det_output.get("quality"), + motion_output, + ) + planning_result = self.planning_decoder.decode( + det_output, motion_output, planning_output, data + ) + return motion_result, planning_result diff --git a/projects/mmdet3d_plugin/models/motion/motion_blocks.py b/projects/mmdet3d_plugin/models/motion/motion_blocks.py new file mode 100644 index 0000000..f7e80e2 --- /dev/null +++ b/projects/mmdet3d_plugin/models/motion/motion_blocks.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import numpy as np + +from mmcv.cnn import Linear, Scale, bias_init_with_prob +from mmcv.runner.base_module import Sequential, BaseModule +from mmcv.cnn import xavier_init +from mmcv.cnn.bricks.registry import ( + PLUGIN_LAYERS, +) + +from projects.mmdet3d_plugin.core.box3d import * +from ..blocks import linear_relu_ln + + +@PLUGIN_LAYERS.register_module() +class MotionPlanningRefinementModule(BaseModule): + def __init__( + self, + embed_dims=256, + fut_ts=12, + fut_mode=6, + ego_fut_ts=6, + ego_fut_mode=3, + ): + super(MotionPlanningRefinementModule, self).__init__() + self.embed_dims = embed_dims + self.fut_ts = fut_ts + 
self.fut_mode = fut_mode + self.ego_fut_ts = ego_fut_ts + self.ego_fut_mode = ego_fut_mode + + self.motion_cls_branch = nn.Sequential( + *linear_relu_ln(embed_dims, 1, 2), + Linear(embed_dims, 1), + ) + self.motion_reg_branch = nn.Sequential( + nn.Linear(embed_dims, embed_dims), + nn.ReLU(), + nn.Linear(embed_dims, embed_dims), + nn.ReLU(), + nn.Linear(embed_dims, fut_ts * 2), + ) + self.plan_cls_branch = nn.Sequential( + *linear_relu_ln(embed_dims, 1, 2), + Linear(embed_dims, 1), + ) + self.plan_reg_branch = nn.Sequential( + nn.Linear(embed_dims, embed_dims), + nn.ReLU(), + nn.Linear(embed_dims, embed_dims), + nn.ReLU(), + nn.Linear(embed_dims, ego_fut_ts * 2), + ) + self.plan_status_branch = nn.Sequential( + nn.Linear(embed_dims, embed_dims), + nn.ReLU(), + nn.Linear(embed_dims, embed_dims), + nn.ReLU(), + nn.Linear(embed_dims, 10), + ) + + def init_weight(self): + bias_init = bias_init_with_prob(0.01) + nn.init.constant_(self.motion_cls_branch[-1].bias, bias_init) + nn.init.constant_(self.plan_cls_branch[-1].bias, bias_init) + + def forward( + self, + motion_query, + plan_query, + ego_feature, + ego_anchor_embed, + ): + bs, num_anchor = motion_query.shape[:2] + motion_cls = self.motion_cls_branch(motion_query).squeeze(-1) + motion_reg = self.motion_reg_branch(motion_query).reshape(bs, num_anchor, self.fut_mode, self.fut_ts, 2) + plan_cls = self.plan_cls_branch(plan_query).squeeze(-1) + plan_reg = self.plan_reg_branch(plan_query).reshape(bs, 1, 3 * self.ego_fut_mode, self.ego_fut_ts, 2) + planning_status = self.plan_status_branch(ego_feature + ego_anchor_embed) + return motion_cls, motion_reg, plan_cls, plan_reg, planning_status \ No newline at end of file diff --git a/projects/mmdet3d_plugin/models/motion/motion_planning_head.py b/projects/mmdet3d_plugin/models/motion/motion_planning_head.py new file mode 100644 index 0000000..64d0b2b --- /dev/null +++ b/projects/mmdet3d_plugin/models/motion/motion_planning_head.py @@ -0,0 +1,509 @@ +from typing import List, 
from typing import List, Optional, Tuple, Union
import warnings
import copy

import numpy as np
import cv2
import torch
import torch.nn as nn

from mmcv.utils import build_from_cfg
from mmcv.cnn import Linear, bias_init_with_prob
from mmcv.runner import BaseModule, force_fp32
from mmcv.cnn.bricks.registry import (
    ATTENTION,
    PLUGIN_LAYERS,
    POSITIONAL_ENCODING,
    FEEDFORWARD_NETWORK,
    NORM_LAYERS,
)
from mmdet.core import reduce_mean
from mmdet.models import HEADS
from mmdet.core.bbox.builder import BBOX_SAMPLERS, BBOX_CODERS
from mmdet.models import build_loss

from projects.mmdet3d_plugin.datasets.utils import box3d_to_corners
from projects.mmdet3d_plugin.core.box3d import *

from ..attention import gen_sineembed_for_position
from ..blocks import linear_relu_ln
from ..instance_bank import topk


@HEADS.register_module()
class MotionPlanningHead(BaseModule):
    """Transformer head that jointly decodes agent motion forecasts and ego
    planning from detection/map instance features.

    All sub-modules (attention layers, refinement module, samplers,
    decoders) are built from config via the mmcv/mmdet registries; the
    layer sequence is driven by ``operation_order``.
    """

    def __init__(
        self,
        fut_ts=12,
        fut_mode=6,
        ego_fut_ts=6,
        ego_fut_mode=3,
        motion_anchor=None,     # path to a .npy of per-class anchor trajectories
        plan_anchor=None,       # path to a .npy of ego anchor trajectories
        embed_dims=256,
        decouple_attn=False,    # concat (not add) positional embeds, see graph_model
        instance_queue=None,
        operation_order=None,   # sequence of op names driving self.layers
        temp_graph_model=None,
        graph_model=None,
        cross_graph_model=None,
        deformable_model=None,
        norm_layer=None,
        ffn=None,
        refine_layer=None,
        motion_sampler=None,
        motion_loss_cls=None,
        motion_loss_reg=None,
        planning_sampler=None,
        plan_loss_cls=None,
        plan_loss_reg=None,
        plan_loss_status=None,
        motion_decoder=None,
        planning_decoder=None,
        num_det=50,             # top-k detection instances used as keys
        num_map=10,             # top-k map instances used as keys
    ):
        super(MotionPlanningHead, self).__init__()
        self.fut_ts = fut_ts
        self.fut_mode = fut_mode
        self.ego_fut_ts = ego_fut_ts
        self.ego_fut_mode = ego_fut_mode

        self.decouple_attn = decouple_attn
        self.operation_order = operation_order

        # =========== build modules ===========
        def build(cfg, registry):
            # None config -> no module (placeholder kept in self.layers).
            if cfg is None:
                return None
            return build_from_cfg(cfg, registry)

        self.instance_queue = build(instance_queue, PLUGIN_LAYERS)
        self.motion_sampler = build(motion_sampler, BBOX_SAMPLERS)
        self.planning_sampler = build(planning_sampler, BBOX_SAMPLERS)
        self.motion_decoder = build(motion_decoder, BBOX_CODERS)
        self.planning_decoder = build(planning_decoder, BBOX_CODERS)
        # Maps an operation name to (config, registry); unknown names build None.
        self.op_config_map = {
            "temp_gnn": [temp_graph_model, ATTENTION],
            "gnn": [graph_model, ATTENTION],
            "cross_gnn": [cross_graph_model, ATTENTION],
            "deformable": [deformable_model, ATTENTION],
            "norm": [norm_layer, NORM_LAYERS],
            "ffn": [ffn, FEEDFORWARD_NETWORK],
            "refine": [refine_layer, PLUGIN_LAYERS],
        }
        self.layers = nn.ModuleList(
            [
                build(*self.op_config_map.get(op, [None, None]))
                for op in self.operation_order
            ]
        )
        self.embed_dims = embed_dims

        if self.decouple_attn:
            # Project to/from the doubled width used when positional
            # embeddings are concatenated instead of added.
            self.fc_before = nn.Linear(
                self.embed_dims, self.embed_dims * 2, bias=False
            )
            self.fc_after = nn.Linear(
                self.embed_dims * 2, self.embed_dims, bias=False
            )
        else:
            self.fc_before = nn.Identity()
            self.fc_after = nn.Identity()

        self.motion_loss_cls = build_loss(motion_loss_cls)
        self.motion_loss_reg = build_loss(motion_loss_reg)
        self.plan_loss_cls = build_loss(plan_loss_cls)
        self.plan_loss_reg = build_loss(plan_loss_reg)
        self.plan_loss_status = build_loss(plan_loss_status)

        # motion init: frozen per-class anchor trajectories loaded from disk.
        motion_anchor = np.load(motion_anchor)
        self.motion_anchor = nn.Parameter(
            torch.tensor(motion_anchor, dtype=torch.float32),
            requires_grad=False,
        )
        self.motion_anchor_encoder = nn.Sequential(
            *linear_relu_ln(embed_dims, 1, 1),
            Linear(embed_dims, embed_dims),
        )

        # plan anchor init: frozen ego anchor trajectories loaded from disk.
        plan_anchor = np.load(plan_anchor)
        self.plan_anchor = nn.Parameter(
            torch.tensor(plan_anchor, dtype=torch.float32),
            requires_grad=False,
        )
        self.plan_anchor_encoder = nn.Sequential(
            *linear_relu_ln(embed_dims, 1, 1),
            Linear(embed_dims, embed_dims),
        )

        self.num_det = num_det
        self.num_map = num_map

    def init_weights(self):
        """Xavier-init all non-refine layers, then let sub-modules that
        define ``init_weight`` (e.g. the refinement module) apply their own
        prior-bias initialization."""
        for i, op in enumerate(self.operation_order):
            if self.layers[i] is None:
                continue
            elif op != "refine":
                for p in self.layers[i].parameters():
                    if p.dim() > 1:
                        nn.init.xavier_uniform_(p)
        for m in self.modules():
            if hasattr(m, "init_weight"):
                m.init_weight()
    def get_motion_anchor(
        self,
        classification,
        prediction,
    ):
        """Pick each agent's anchor trajectories by its argmax class and
        rotate them from agent frame into lidar frame."""
        cls_ids = classification.argmax(dim=-1)
        # self.motion_anchor holds per-class anchor trajectories (loaded in __init__).
        motion_anchor = self.motion_anchor[cls_ids]
        # detach: anchor rotation must not backprop into the detector boxes.
        prediction = prediction.detach()
        return self._agent2lidar(motion_anchor, prediction)

    def _agent2lidar(self, trajs, boxes):
        """Rotate 2D trajectories by each box's yaw.

        From the einsum: trajs is (bs, num_anchor, mode, ts, 2) and boxes
        carry sin/cos yaw at indices SIN_YAW/COS_YAW -- TODO confirm
        against the box3d layout.
        """
        yaw = torch.atan2(boxes[..., SIN_YAW], boxes[..., COS_YAW])
        cos_yaw = torch.cos(yaw)
        sin_yaw = torch.sin(yaw)
        # (2, 2, bs, num_anchor): transposed per-agent rotation matrices.
        rot_mat_T = torch.stack(
            [
                torch.stack([cos_yaw, sin_yaw]),
                torch.stack([-sin_yaw, cos_yaw]),
            ]
        )

        trajs_lidar = torch.einsum('abcij,jkab->abcik', trajs, rot_mat_T)
        return trajs_lidar

    def graph_model(
        self,
        index,
        query,
        key=None,
        value=None,
        query_pos=None,
        key_pos=None,
        **kwargs,
    ):
        """Run attention layer ``index``.

        With decouple_attn the positional embeddings are concatenated onto
        the features (and fc_before/fc_after project value/output across
        the doubled width) instead of being added inside the layer.
        """
        if self.decouple_attn:
            query = torch.cat([query, query_pos], dim=-1)
            if key is not None:
                key = torch.cat([key, key_pos], dim=-1)
            query_pos, key_pos = None, None
        if value is not None:
            value = self.fc_before(value)
        return self.fc_after(
            self.layers[index](
                query,
                key,
                value,
                query_pos=query_pos,
                key_pos=key_pos,
                **kwargs,
            )
        )
    def forward(
        self,
        det_output,
        map_output,
        feature_maps,
        metas,
        anchor_encoder,
        mask,
        anchor_handler,
    ):
        """Decode motion and planning outputs from detection/map instances.

        Returns (motion_output, planning_output) dicts with per-decoder
        lists of classification/regression tensors plus temporal-queue
        bookkeeping from ``self.instance_queue``.
        """
        # =========== det/map feature/anchor ===========
        instance_feature = det_output["instance_feature"]
        anchor_embed = det_output["anchor_embed"]
        det_classification = det_output["classification"][-1].sigmoid()
        det_anchors = det_output["prediction"][-1]
        det_confidence = det_classification.max(dim=-1).values
        # Keep only the num_det most confident agents as attention keys.
        _, (instance_feature_selected, anchor_embed_selected) = topk(
            det_confidence, self.num_det, instance_feature, anchor_embed
        )

        if map_output is not None:
            map_instance_feature = map_output["instance_feature"]
            map_anchor_embed = map_output["anchor_embed"]
            map_classification = map_output["classification"][-1].sigmoid()
            map_anchors = map_output["prediction"][-1]
            map_confidence = map_classification.max(dim=-1).values
            _, (map_instance_feature_selected, map_anchor_embed_selected) = topk(
                map_confidence, self.num_map, map_instance_feature, map_anchor_embed
            )

        # =========== get ego/temporal feature/anchor ===========
        bs, num_anchor, dim = instance_feature.shape
        (
            ego_feature,
            ego_anchor,
            temp_instance_feature,
            temp_anchor,
            temp_mask,
        ) = self.instance_queue.get(
            det_output,
            feature_maps,
            metas,
            bs,
            mask,
            anchor_handler,
        )
        ego_anchor_embed = anchor_encoder(ego_anchor)
        temp_anchor_embed = anchor_encoder(temp_anchor)
        # Flatten (bs, num_anchor) so temporal attention runs per instance.
        temp_instance_feature = temp_instance_feature.flatten(0, 1)
        temp_anchor_embed = temp_anchor_embed.flatten(0, 1)
        temp_mask = temp_mask.flatten(0, 1)

        # =========== mode anchor init ===========
        motion_anchor = self.get_motion_anchor(det_classification, det_anchors)
        plan_anchor = torch.tile(
            self.plan_anchor[None], (bs, 1, 1, 1, 1)
        )

        # =========== mode query init ===========
        # Mode queries are sine embeddings of each anchor trajectory's endpoint.
        motion_mode_query = self.motion_anchor_encoder(gen_sineembed_for_position(motion_anchor[..., -1, :]))
        plan_pos = gen_sineembed_for_position(plan_anchor[..., -1, :])
        plan_mode_query = self.plan_anchor_encoder(plan_pos).flatten(1, 2).unsqueeze(1)

        # =========== cat instance and ego ===========
        # Ego token is appended as the last instance (index num_anchor).
        instance_feature_selected = torch.cat([instance_feature_selected, ego_feature], dim=1)
        anchor_embed_selected = torch.cat([anchor_embed_selected, ego_anchor_embed], dim=1)

        instance_feature = torch.cat([instance_feature, ego_feature], dim=1)
        anchor_embed = torch.cat([anchor_embed, ego_anchor_embed], dim=1)

        # =================== forward the layers ====================
        motion_classification = []
        motion_prediction = []
        planning_classification = []
        planning_prediction = []
        planning_status = []
        for i, op in enumerate(self.operation_order):
            if self.layers[i] is None:
                continue
            elif op == "temp_gnn":
                # Temporal attention: each instance attends to its own history.
                instance_feature = self.graph_model(
                    i,
                    instance_feature.flatten(0, 1).unsqueeze(1),
                    temp_instance_feature,
                    temp_instance_feature,
                    query_pos=anchor_embed.flatten(0, 1).unsqueeze(1),
                    key_pos=temp_anchor_embed,
                    key_padding_mask=temp_mask,
                )
                instance_feature = instance_feature.reshape(bs, num_anchor + 1, dim)
            elif op == "gnn":
                # Agent-agent (+ego) interaction over the top-k selected keys.
                instance_feature = self.graph_model(
                    i,
                    instance_feature,
                    instance_feature_selected,
                    instance_feature_selected,
                    query_pos=anchor_embed,
                    key_pos=anchor_embed_selected,
                )
            elif op == "norm" or op == "ffn":
                instance_feature = self.layers[i](instance_feature)
            elif op == "cross_gnn":
                # Agent-map interaction (requires map_output to be present).
                instance_feature = self.layers[i](
                    instance_feature,
                    key=map_instance_feature_selected,
                    query_pos=anchor_embed,
                    key_pos=map_anchor_embed_selected,
                )
            elif op == "deformable":
                # Apply deformable cross-attention to sensor features for
                # agent instances only (ego token has no well-defined 3D box).
                agent_feature = self.layers[i](
                    instance_feature[:, :num_anchor],
                    det_anchors,
                    anchor_embed[:, :num_anchor],
                    feature_maps,
                    metas,
                )
                instance_feature = torch.cat(
                    [agent_feature, instance_feature[:, num_anchor:]], dim=1
                )
            elif op == "refine":
                # Broadcast instance features over modes and refine.
                motion_query = motion_mode_query + (instance_feature + anchor_embed)[:, :num_anchor].unsqueeze(2)
                plan_query = plan_mode_query + (instance_feature + anchor_embed)[:, num_anchor:].unsqueeze(2)
                (
                    motion_cls,
                    motion_reg,
                    plan_cls,
                    plan_reg,
                    plan_status,
                ) = self.layers[i](
                    motion_query,
                    plan_query,
                    instance_feature[:, num_anchor:],
                    anchor_embed[:, num_anchor:],
                )
                motion_classification.append(motion_cls)
                motion_prediction.append(motion_reg)
                planning_classification.append(plan_cls)
                planning_prediction.append(plan_reg)
                planning_status.append(plan_status)
                # Update mode anchor queries for the next decoder iteration.
                # cumsum converts delta trajectories to absolute endpoints.
                motion_anchor_upd = motion_reg.detach().cumsum(dim=-2)
                motion_mode_query = self.motion_anchor_encoder(
                    gen_sineembed_for_position(motion_anchor_upd[..., -1, :])
                )
                plan_anchor_upd = plan_reg.detach().cumsum(dim=-2)
                plan_mode_query = self.plan_anchor_encoder(
                    gen_sineembed_for_position(plan_anchor_upd[..., -1, :])
                ).flatten(1, 2).unsqueeze(1)

        # Persist agent/ego features for the temporal queue of the next frame.
        self.instance_queue.cache_motion(instance_feature[:, :num_anchor], det_output, metas)
        self.instance_queue.cache_planning(instance_feature[:, num_anchor:], plan_status)

        motion_output = {
            "classification": motion_classification,
            "prediction": motion_prediction,
            "period": self.instance_queue.period,
            "anchor_queue": self.instance_queue.anchor_queue,
        }
        planning_output = {
            "classification": planning_classification,
            "prediction": planning_prediction,
            "status": planning_status,
            "period": self.instance_queue.ego_period,
            "anchor_queue": self.instance_queue.ego_anchor_queue,
        }
        return motion_output, planning_output
    def loss(self,
        motion_model_outs,
        planning_model_outs,
        data,
        motion_loss_cache
    ):
        """Aggregate motion-forecast and planning losses into one dict."""
        loss = {}
        motion_loss = self.loss_motion(motion_model_outs, data, motion_loss_cache)
        loss.update(motion_loss)
        planning_loss = self.loss_planning(planning_model_outs, data)
        loss.update(planning_loss)
        return loss

    # NOTE(review): apply_to=("model_outs") is a plain string, not a tuple;
    # mmcv's membership test still matches it, but ("model_outs",) is the
    # conventional spelling.
    @force_fp32(apply_to=("model_outs"))
    def loss_motion(self, model_outs, data, motion_loss_cache):
        """Per-decoder motion classification + regression losses.

        Targets come from the motion sampler using the detector's matching
        indices (motion_loss_cache); cumsum converts delta trajectories to
        absolute positions before the regression loss.
        """
        cls_scores = model_outs["classification"]
        reg_preds = model_outs["prediction"]
        output = {}
        for decoder_idx, (cls, reg) in enumerate(
            zip(cls_scores, reg_preds)
        ):
            (
                cls_target,
                cls_weight,
                reg_pred,
                reg_target,
                reg_weight,
                num_pos
            ) = self.motion_sampler.sample(
                reg,
                data["gt_agent_fut_trajs"],
                data["gt_agent_fut_masks"],
                motion_loss_cache,
            )
            # Average positives across distributed workers; never below 1.
            num_pos = max(reduce_mean(num_pos), 1.0)

            cls = cls.flatten(end_dim=1)
            cls_target = cls_target.flatten(end_dim=1)
            cls_weight = cls_weight.flatten(end_dim=1)
            cls_loss = self.motion_loss_cls(cls, cls_target, weight=cls_weight, avg_factor=num_pos)

            reg_weight = reg_weight.flatten(end_dim=1)
            reg_pred = reg_pred.flatten(end_dim=1)
            reg_target = reg_target.flatten(end_dim=1)
            reg_weight = reg_weight.unsqueeze(-1)
            # Compare absolute positions, not per-step deltas.
            reg_pred = reg_pred.cumsum(dim=-2)
            reg_target = reg_target.cumsum(dim=-2)
            reg_loss = self.motion_loss_reg(
                reg_pred, reg_target, weight=reg_weight, avg_factor=num_pos
            )

            output.update(
                {
                    f"motion_loss_cls_{decoder_idx}": cls_loss,
                    f"motion_loss_reg_{decoder_idx}": reg_loss,
                }
            )

        return output

    @force_fp32(apply_to=("model_outs"))
    def loss_planning(self, model_outs, data):
        """Per-decoder planning classification, regression and ego-status
        losses; the planning sampler selects the command-conditioned modes."""
        cls_scores = model_outs["classification"]
        reg_preds = model_outs["prediction"]
        status_preds = model_outs["status"]
        output = {}
        for decoder_idx, (cls, reg, status) in enumerate(
            zip(cls_scores, reg_preds, status_preds)
        ):
            (
                cls,
                cls_target,
                cls_weight,
                reg_pred,
                reg_target,
                reg_weight,
            ) = self.planning_sampler.sample(
                cls,
                reg,
                data['gt_ego_fut_trajs'],
                data['gt_ego_fut_masks'],
                data,
            )
            cls = cls.flatten(end_dim=1)
            cls_target = cls_target.flatten(end_dim=1)
            cls_weight = cls_weight.flatten(end_dim=1)
            cls_loss = self.plan_loss_cls(cls, cls_target, weight=cls_weight)

            reg_weight = reg_weight.flatten(end_dim=1)
            reg_pred = reg_pred.flatten(end_dim=1)
            reg_target = reg_target.flatten(end_dim=1)
            reg_weight = reg_weight.unsqueeze(-1)

            reg_loss = self.plan_loss_reg(
                reg_pred, reg_target, weight=reg_weight
            )
            status_loss = self.plan_loss_status(status.squeeze(1), data['ego_status'])

            output.update(
                {
                    f"planning_loss_cls_{decoder_idx}": cls_loss,
                    f"planning_loss_reg_{decoder_idx}": reg_loss,
                    f"planning_loss_status_{decoder_idx}": status_loss,
                }
            )

        return output
def get_cls_target(
    reg_preds,
    reg_target,
    reg_weight,
):
    """Return, per instance, the index of the predicted mode whose
    cumulative trajectory has the smallest weighted mean distance to the
    cumulative target trajectory.

    reg_preds: (bs, num_pred, mode, ts, d) delta trajectories.
    reg_target: (bs, num_pred, ts, d); reg_weight: (bs, num_pred, ts).
    """
    pred_cum = torch.cumsum(reg_preds, dim=-2)
    target_cum = torch.cumsum(reg_target, dim=-2)
    # (bs, num_pred, mode, ts) pointwise distances, masked by the weights.
    point_err = torch.linalg.norm(target_cum.unsqueeze(2) - pred_cum, dim=-1)
    mode_err = (point_err * reg_weight.unsqueeze(2)).mean(dim=-1)
    return mode_err.argmin(dim=-1)


def get_best_reg(
    reg_preds,
    reg_target,
    reg_weight,
):
    """Gather, per instance, the predicted mode closest to the target
    (same winner criterion as get_cls_target). Returns (bs, num_pred, ts, d)."""
    num_ts, num_dim = reg_preds.shape[-2:]
    winner = get_cls_target(reg_preds, reg_target, reg_weight)
    gather_idx = winner[..., None, None, None].expand(
        -1, -1, -1, num_ts, num_dim
    )
    return reg_preds.gather(2, gather_idx).squeeze(2)
@BBOX_SAMPLERS.register_module()
class MotionTarget():
    """Builds motion-forecast targets from the detector's matching results."""

    def __init__(
        self,
    ):
        super(MotionTarget, self).__init__()

    def sample(
        self,
        reg_pred,
        gt_reg_target,
        gt_reg_mask,
        motion_loss_cache,
    ):
        """Scatter GT future trajectories onto matched anchors and pick the
        closest predicted mode per anchor.

        reg_pred: (bs, num_anchor, mode, ts, d) delta trajectories.
        motion_loss_cache['indices']: per-sample (pred_idx, target_idx)
        matching pairs produced by the detection sampler.
        """
        bs, num_anchor, mode, ts, d = reg_pred.shape
        reg_target = reg_pred.new_zeros((bs, num_anchor, ts, d))
        reg_weight = reg_pred.new_zeros((bs, num_anchor, ts))
        indices = motion_loss_cache['indices']
        num_pos = reg_pred.new_tensor([0])
        for i, (pred_idx, target_idx) in enumerate(indices):
            if len(gt_reg_target[i]) == 0:
                continue
            reg_target[i, pred_idx] = gt_reg_target[i][target_idx]
            reg_weight[i, pred_idx] = gt_reg_mask[i][target_idx]
            num_pos += len(pred_idx)

        cls_target = get_cls_target(reg_pred, reg_target, reg_weight)
        # An anchor contributes to classification if any timestep is valid.
        cls_weight = reg_weight.any(dim=-1)
        best_reg = get_best_reg(reg_pred, reg_target, reg_weight)

        return cls_target, cls_weight, best_reg, reg_target, reg_weight, num_pos


@BBOX_SAMPLERS.register_module()
class PlanningTarget():
    """Builds planning targets conditioned on the driving command."""

    def __init__(
        self,
        ego_fut_ts,
        ego_fut_mode,
    ):
        super(PlanningTarget, self).__init__()
        self.ego_fut_ts = ego_fut_ts
        self.ego_fut_mode = ego_fut_mode

    def sample(
        self,
        cls_pred,
        reg_pred,
        gt_reg_target,
        gt_reg_mask,
        data,
    ):
        """Select the ego modes belonging to the GT command and pick the
        closest one to the GT ego trajectory.

        Predictions are laid out as 3 commands x ego_fut_mode modes
        (matching the refinement module's 3 * ego_fut_mode reshape).
        """
        gt_reg_target = gt_reg_target.unsqueeze(1)
        gt_reg_mask = gt_reg_mask.unsqueeze(1)

        bs = reg_pred.shape[0]
        bs_indices = torch.arange(bs, device=reg_pred.device)
        cmd = data['gt_ego_fut_cmd'].argmax(dim=-1)

        cls_pred = cls_pred.reshape(bs, 3, 1, self.ego_fut_mode)
        reg_pred = reg_pred.reshape(bs, 3, 1, self.ego_fut_mode, self.ego_fut_ts, 2)
        # Keep only the modes of the commanded branch.
        cls_pred = cls_pred[bs_indices, cmd]
        reg_pred = reg_pred[bs_indices, cmd]
        cls_target = get_cls_target(reg_pred, gt_reg_target, gt_reg_mask)
        cls_weight = gt_reg_mask.any(dim=-1)
        best_reg = get_best_reg(reg_pred, gt_reg_target, gt_reg_mask)

        return cls_pred, cls_target, cls_weight, best_reg, gt_reg_target, gt_reg_mask
@DETECTORS.register_module()
class SparseDrive(BaseDetector):
    """End-to-end sparse detector driving the SparseDriveHead.

    Extracts multi-view image features (backbone + optional neck, optional
    dense-depth branch), optionally reformats them for the deformable
    aggregation op, and delegates training/inference to ``self.head``.
    """

    def __init__(
        self,
        img_backbone,
        head,
        img_neck=None,
        init_cfg=None,
        train_cfg=None,
        test_cfg=None,
        pretrained=None,
        use_grid_mask=True,
        use_deformable_func=False,
        depth_branch=None,
    ):
        super(SparseDrive, self).__init__(init_cfg=init_cfg)
        if pretrained is not None:
            # BUGFIX: was `backbone.pretrained = pretrained`, but no local
            # name `backbone` exists (NameError). Inject the checkpoint into
            # the backbone config dict (mmcv ConfigDict supports attribute
            # assignment) before building it.
            img_backbone.pretrained = pretrained
        self.img_backbone = build_backbone(img_backbone)
        # BUGFIX: always define self.img_neck -- extract_feat tests it with
        # `is not None`, which raised AttributeError when no neck was given.
        self.img_neck = build_neck(img_neck) if img_neck is not None else None
        self.head = build_head(head)
        self.use_grid_mask = use_grid_mask
        if use_deformable_func:
            assert DAF_VALID, "deformable_aggregation needs to be set up."
        self.use_deformable_func = use_deformable_func
        if depth_branch is not None:
            self.depth_branch = build_from_cfg(depth_branch, PLUGIN_LAYERS)
        else:
            self.depth_branch = None
        if use_grid_mask:
            self.grid_mask = GridMask(
                True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
            )

    @auto_fp16(apply_to=("img",), out_fp32=True)
    def extract_feat(self, img, return_depth=False, metas=None):
        """Run backbone (+neck) and reshape features back to
        (bs, num_cams, C, H, W); optionally also return dense depths."""
        bs = img.shape[0]
        if img.dim() == 5:  # multi-view: (bs, num_cams, C, H, W)
            num_cams = img.shape[1]
            img = img.flatten(end_dim=1)
        else:
            num_cams = 1
        if self.use_grid_mask:
            img = self.grid_mask(img)
        # Some backbones accept metas (e.g. for camera-aware processing).
        if "metas" in signature(self.img_backbone.forward).parameters:
            feature_maps = self.img_backbone(img, num_cams, metas=metas)
        else:
            feature_maps = self.img_backbone(img)
        if self.img_neck is not None:
            feature_maps = list(self.img_neck(feature_maps))
        for i, feat in enumerate(feature_maps):
            feature_maps[i] = torch.reshape(
                feat, (bs, num_cams) + feat.shape[1:]
            )
        if return_depth and self.depth_branch is not None:
            depths = self.depth_branch(feature_maps, metas.get("focal"))
        else:
            depths = None
        if self.use_deformable_func:
            # Pack to the column layout the custom CUDA op consumes.
            feature_maps = feature_maps_format(feature_maps)
        if return_depth:
            return feature_maps, depths
        return feature_maps

    @force_fp32(apply_to=("img",))
    def forward(self, img, **data):
        if self.training:
            return self.forward_train(img, **data)
        else:
            return self.forward_test(img, **data)

    def forward_train(self, img, **data):
        """Training step: head losses plus optional dense-depth loss."""
        feature_maps, depths = self.extract_feat(img, True, data)
        model_outs = self.head(feature_maps, data)
        output = self.head.loss(model_outs, data)
        if depths is not None and "gt_depth" in data:
            output["loss_dense_depth"] = self.depth_branch.loss(
                depths, data["gt_depth"]
            )
        return output

    def forward_test(self, img, **data):
        if isinstance(img, list):
            return self.aug_test(img, **data)
        else:
            return self.simple_test(img, **data)

    def simple_test(self, img, **data):
        feature_maps = self.extract_feat(img)

        model_outs = self.head(feature_maps, data)
        results = self.head.post_process(model_outs, data)
        output = [{"img_bbox": result} for result in results]
        return output

    def aug_test(self, img, **data):
        # fake test time augmentation: only the first view-set is evaluated.
        for key in data.keys():
            if isinstance(data[key], list):
                data[key] = data[key][0]
        return self.simple_test(img[0], **data)
@HEADS.register_module()
class SparseDriveHead(BaseModule):
    """Composite head dispatching to detection / map / motion-planning
    sub-heads according to ``task_config`` flags (with_det / with_map /
    with_motion_plan)."""

    def __init__(
        self,
        task_config: dict,
        det_head=None,
        map_head=None,
        motion_plan_head=None,
        init_cfg=None,
        **kwargs,
    ):
        # NOTE: defaults were previously the `dict` *type*, which would
        # crash inside build_head if a task was enabled without a config;
        # None makes the unused-task case explicit.
        super(SparseDriveHead, self).__init__(init_cfg)
        self.task_config = task_config
        if self.task_config['with_det']:
            self.det_head = build_head(det_head)
        if self.task_config['with_map']:
            self.map_head = build_head(map_head)
        if self.task_config['with_motion_plan']:
            self.motion_plan_head = build_head(motion_plan_head)

    def init_weights(self):
        if self.task_config['with_det']:
            self.det_head.init_weights()
        if self.task_config['with_map']:
            self.map_head.init_weights()
        if self.task_config['with_motion_plan']:
            self.motion_plan_head.init_weights()

    def forward(
        self,
        feature_maps: Union[torch.Tensor, List],
        metas: dict,
    ):
        """Run enabled sub-heads; motion/planning consumes det/map outputs."""
        if self.task_config['with_det']:
            det_output = self.det_head(feature_maps, metas)
        else:
            det_output = None

        if self.task_config['with_map']:
            map_output = self.map_head(feature_maps, metas)
        else:
            map_output = None

        if self.task_config['with_motion_plan']:
            motion_output, planning_output = self.motion_plan_head(
                det_output,
                map_output,
                feature_maps,
                metas,
                self.det_head.anchor_encoder,
                self.det_head.instance_bank.mask,
                self.det_head.instance_bank.anchor_handler,
            )
        else:
            motion_output, planning_output = None, None

        return det_output, map_output, motion_output, planning_output

    def loss(self, model_outs, data):
        """Collect losses from all enabled sub-heads into one dict."""
        det_output, map_output, motion_output, planning_output = model_outs
        losses = dict()
        if self.task_config['with_det']:
            loss_det = self.det_head.loss(det_output, data)
            losses.update(loss_det)

        if self.task_config['with_map']:
            loss_map = self.map_head.loss(map_output, data)
            losses.update(loss_map)

        if self.task_config['with_motion_plan']:
            # The motion targets reuse the detector's matching indices.
            motion_loss_cache = dict(
                indices=self.det_head.sampler.indices,
            )
            loss_motion = self.motion_plan_head.loss(
                motion_output,
                planning_output,
                data,
                motion_loss_cache
            )
            losses.update(loss_motion)

        return losses

    def post_process(self, model_outs, data):
        """Merge per-sample result dicts from all enabled sub-heads."""
        det_output, map_output, motion_output, planning_output = model_outs
        batch_size = None
        if self.task_config['with_det']:
            det_result = self.det_head.post_process(det_output)
            batch_size = len(det_result)

        if self.task_config['with_map']:
            map_result = self.map_head.post_process(map_output)
            batch_size = len(map_result)

        if self.task_config['with_motion_plan']:
            motion_result, planning_result = self.motion_plan_head.post_process(
                det_output,
                motion_output,
                planning_output,
                data,
            )
            # BUGFIX: batch_size was unbound when only motion/plan was enabled.
            if batch_size is None:
                batch_size = len(planning_result)

        # BUGFIX: `[dict()] * batch_size` replicates ONE shared dict, so
        # every results[i].update() merged all samples into the same object.
        results = [dict() for _ in range(batch_size)]
        for i in range(batch_size):
            if self.task_config['with_det']:
                results[i].update(det_result[i])
            if self.task_config['with_map']:
                results[i].update(map_result[i])
            if self.task_config['with_motion_plan']:
                results[i].update(motion_result[i])
                results[i].update(planning_result[i])

        return results
def feature_maps_format(feature_maps, inverse=False):
    """Convert multi-camera multi-scale feature maps to/from column layout.

    Forward (inverse=False): given a list of (bs, num_cams, C, H_l, W_l)
    tensors (one per scale), returns
    [col_feats, spatial_shape, scale_start_index] where col_feats is
    (bs, num_cams * sum(H_l * W_l), C), spatial_shape is
    (num_cams, num_levels, 2) and scale_start_index gives each
    (cam, scale)'s offset into the flattened pixel axis. A list of such
    lists is packed recursively and concatenated. Inverse reverses the
    packing into per-camera-group, per-scale tensors.
    """
    if inverse:
        col_feats, spatial_shape, scale_start_index = feature_maps
        num_cams = spatial_shape.shape[0]

        # Pixel counts per (cam, scale).
        per_scale = (spatial_shape[..., 0] * spatial_shape[..., 1])
        per_scale = per_scale.cpu().numpy().tolist()

        # Group consecutive cameras that share identical spatial shapes so
        # they can be unpacked with one reshape per group.
        group_counts = [1]
        group_sizes = [sum(per_scale[0])]
        for cam in range(num_cams - 1):
            if not torch.all(spatial_shape[cam] == spatial_shape[cam + 1]):
                group_counts.append(0)
                group_sizes.append(0)
            group_counts[-1] += 1
            group_sizes[-1] += sum(per_scale[cam + 1])

        grouped = [
            chunk.unflatten(1, (group_counts[g], -1))
            for g, chunk in enumerate(col_feats.split(group_sizes, dim=1))
        ]

        shapes = spatial_shape.cpu().numpy().tolist()
        out = []
        cam_idx = 0
        for g, feat in enumerate(grouped):
            scales = list(feat.split(per_scale[cam_idx], dim=2))
            for s, f in enumerate(scales):
                # (bs, cams, H*W, C) -> (bs, cams, H, W, C) -> (bs, cams, C, H, W)
                scales[s] = f.unflatten(2, shapes[cam_idx][s]).permute(0, 1, 4, 2, 3)
            out.append(scales)
            cam_idx += group_counts[g]
        return out

    if isinstance(feature_maps[0], (list, tuple)):
        # Nested input: pack each group, then concatenate along cams/pixels.
        packed = [feature_maps_format(fm) for fm in feature_maps]
        return [
            torch.cat([p[0] for p in packed], dim=1),
            torch.cat([p[1] for p in packed], dim=0),
            torch.cat([p[2] for p in packed], dim=0),
        ]

    bs, num_cams = feature_maps[0].shape[:2]
    shapes = [feat.shape[-2:] for feat in feature_maps]
    flat = [
        feat.reshape(bs, num_cams, feat.shape[2], -1) for feat in feature_maps
    ]
    # (bs, cams, C, sum HW) -> (bs, cams * sum HW, C)
    col_feats = torch.cat(flat, dim=-1).permute(0, 1, 3, 2).flatten(1, 2)

    # All cameras share the same per-scale shapes in this branch.
    spatial_shape = torch.tensor(
        [shapes] * num_cams,
        dtype=torch.int64,
        device=col_feats.device,
    )
    starts = (spatial_shape[..., 0] * spatial_shape[..., 1]).flatten().cumsum(dim=0)
    starts = torch.cat([torch.tensor([0]).to(starts), starts[:-1]])
    return [col_feats, spatial_shape, starts.reshape(num_cams, -1)]
class DeformableAggregationFunction(Function):
    """Autograd wrapper around the compiled deformable-aggregation kernels
    (``deformable_aggregation_ext``), which sample multi-camera,
    multi-scale column features at given locations and blend them with
    per-point weights."""

    @staticmethod
    def forward(
        ctx,
        mc_ms_feat,
        spatial_shape,
        scale_start_index,
        sampling_location,
        weights,
    ):
        # output: [bs, num_pts, num_embeds]
        # The extension expects contiguous float32 features/locations/weights
        # and int32 shape/index tensors, so cast defensively here.
        mc_ms_feat = mc_ms_feat.contiguous().float()
        spatial_shape = spatial_shape.contiguous().int()
        scale_start_index = scale_start_index.contiguous().int()
        sampling_location = sampling_location.contiguous().float()
        weights = weights.contiguous().float()
        output = deformable_aggregation_ext.deformable_aggregation_forward(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        )
        # Save the (cast) inputs for the backward kernel.
        ctx.save_for_backward(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        )
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        (
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        ) = ctx.saved_tensors
        mc_ms_feat = mc_ms_feat.contiguous().float()
        spatial_shape = spatial_shape.contiguous().int()
        scale_start_index = scale_start_index.contiguous().int()
        sampling_location = sampling_location.contiguous().float()
        weights = weights.contiguous().float()

        # The CUDA kernel accumulates into pre-zeroed gradient buffers.
        grad_mc_ms_feat = torch.zeros_like(mc_ms_feat)
        grad_sampling_location = torch.zeros_like(sampling_location)
        grad_weights = torch.zeros_like(weights)
        deformable_aggregation_ext.deformable_aggregation_backward(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
            grad_output.contiguous(),
            grad_mc_ms_feat,
            grad_sampling_location,
            grad_weights,
        )
        # No gradients for the integer shape/index tensors.
        return (
            grad_mc_ms_feat,
            None,
            None,
            grad_sampling_location,
            grad_weights,
        )
+1,60 @@ +import os + +import torch +from setuptools import setup +from torch.utils.cpp_extension import ( + BuildExtension, + CppExtension, + CUDAExtension, +) + + +def make_cuda_ext( + name, + module, + sources, + sources_cuda=[], + extra_args=[], + extra_include_path=[], +): + + define_macros = [] + extra_compile_args = {"cxx": [] + extra_args} + + if torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1": + define_macros += [("WITH_CUDA", None)] + extension = CUDAExtension + extra_compile_args["nvcc"] = extra_args + [ + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", + ] + sources += sources_cuda + else: + print("Compiling {} without CUDA".format(name)) + extension = CppExtension + + return extension( + name="{}.{}".format(module, name), + sources=[os.path.join(*module.split("."), p) for p in sources], + include_dirs=extra_include_path, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + + +if __name__ == "__main__": + setup( + name="deformable_aggregation_ext", + ext_modules=[ + make_cuda_ext( + "deformable_aggregation_ext", + module=".", + sources=[ + f"src/deformable_aggregation.cpp", + f"src/deformable_aggregation_cuda.cu", + ], + ), + ], + cmdclass={"build_ext": BuildExtension}, + ) diff --git a/projects/mmdet3d_plugin/ops/src/deformable_aggregation.cpp b/projects/mmdet3d_plugin/ops/src/deformable_aggregation.cpp new file mode 100644 index 0000000..68356a7 --- /dev/null +++ b/projects/mmdet3d_plugin/ops/src/deformable_aggregation.cpp @@ -0,0 +1,138 @@ +#include +#include + +void deformable_aggregation( + float* output, + const float* mc_ms_feat, + const int* spatial_shape, + const int* scale_start_index, + const float* sample_location, + const float* weights, + int batch_size, + int num_cams, + int num_feat, + int num_embeds, + int num_scale, + int num_anchors, + int num_pts, + int num_groups +); + + +/* feat: bs, num_feat, c */ +/* _spatial_shape: cam, scale, 2 */ 
/* Tensor layouts shared by forward/backward:                        */
/* feat: bs, num_feat, c                                             */
/* _spatial_shape: cam, scale, 2                                     */
/* _scale_start_index: cam, scale                                    */
/* _sampling_location: bs, anchor, pts, cam, 2                       */
/* _weights: bs, anchor, pts, cam, scale, group                      */
/* output: bs, anchor, c                                             */
/* kernel: bs, anchor, pts, c                                        */


// Aggregate column features at sampled locations with per-point weights.
// Returns a zero-initialized (bs, num_anchors, num_embeds) tensor filled
// by the CUDA launcher `deformable_aggregation` (declared above,
// implemented in deformable_aggregation_cuda.cu).
at::Tensor deformable_aggregation_forward(
    const at::Tensor &_mc_ms_feat,
    const at::Tensor &_spatial_shape,
    const at::Tensor &_scale_start_index,
    const at::Tensor &_sampling_location,
    const at::Tensor &_weights
) {
    // Keep all kernel launches on the feature tensor's device.
    at::DeviceGuard guard(_mc_ms_feat.device());
    const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat));
    int batch_size = _mc_ms_feat.size(0);
    int num_feat = _mc_ms_feat.size(1);
    int num_embeds = _mc_ms_feat.size(2);
    int num_cams = _spatial_shape.size(0);
    int num_scale = _spatial_shape.size(1);
    int num_anchors = _sampling_location.size(1);
    int num_pts = _sampling_location.size(2);
    int num_groups = _weights.size(5);

    // NOTE(review): data_ptr() without explicit dtype checks -- the Python
    // wrapper casts to float32/int32 before calling; verify no other caller
    // bypasses it.
    const float* mc_ms_feat = _mc_ms_feat.data_ptr<float>();
    const int* spatial_shape = _spatial_shape.data_ptr<int>();
    const int* scale_start_index = _scale_start_index.data_ptr<int>();
    const float* sampling_location = _sampling_location.data_ptr<float>();
    const float* weights = _weights.data_ptr<float>();

    auto output = at::zeros({batch_size, num_anchors, num_embeds}, _mc_ms_feat.options());
    deformable_aggregation(
        output.data_ptr<float>(),
        mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
        batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
    );
    return output;
}


// CUDA launcher for the backward pass (deformable_aggregation_cuda.cu).
void deformable_aggregation_grad(
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    const float* grad_output,
    float* grad_mc_ms_feat,
    float* grad_sampling_location,
    float* grad_weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
);
&_mc_ms_feat, + const at::Tensor &_spatial_shape, + const at::Tensor &_scale_start_index, + const at::Tensor &_sampling_location, + const at::Tensor &_weights, + const at::Tensor &_grad_output, + at::Tensor &_grad_mc_ms_feat, + at::Tensor &_grad_sampling_location, + at::Tensor &_grad_weights +) { + at::DeviceGuard guard(_mc_ms_feat.device()); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat)); + int batch_size = _mc_ms_feat.size(0); + int num_feat = _mc_ms_feat.size(1); + int num_embeds = _mc_ms_feat.size(2); + int num_cams = _spatial_shape.size(0); + int num_scale = _spatial_shape.size(1); + int num_anchors = _sampling_location.size(1); + int num_pts = _sampling_location.size(2); + int num_groups = _weights.size(5); + + const float* mc_ms_feat = _mc_ms_feat.data_ptr(); + const int* spatial_shape = _spatial_shape.data_ptr(); + const int* scale_start_index = _scale_start_index.data_ptr(); + const float* sampling_location = _sampling_location.data_ptr(); + const float* weights = _weights.data_ptr(); + const float* grad_output = _grad_output.data_ptr(); + + float* grad_mc_ms_feat = _grad_mc_ms_feat.data_ptr(); + float* grad_sampling_location = _grad_sampling_location.data_ptr(); + float* grad_weights = _grad_weights.data_ptr(); + + deformable_aggregation_grad( + mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights, + grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights, + batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups + ); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "deformable_aggregation_forward", + &deformable_aggregation_forward, + "deformable_aggregation_forward" + ); + m.def( + "deformable_aggregation_backward", + &deformable_aggregation_backward, + "deformable_aggregation_backward" + ); +} diff --git a/projects/mmdet3d_plugin/ops/src/deformable_aggregation_cuda.cu b/projects/mmdet3d_plugin/ops/src/deformable_aggregation_cuda.cu new file mode 100644 
// ===========================================================================
// file: projects/mmdet3d_plugin/ops/src/deformable_aggregation_cuda.cu
// Reconstructed from a whitespace-collapsed diff; re-indented.
// ---------------------------------------------------------------------------
// The collapse stripped every "<...>" span, so the original eight include
// names are lost.  Reconstructed from what the code uses; the grouping
// (4 / 1 / 2) matches the original blank-line pattern.
// NOTE(review): confirm the exact header names against the upstream repo.
// ===========================================================================
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

#include <cmath>

#include <vector>
#include <torch/extension.h>


// Bilinearly sample one channel from a (H, W, num_embeds) feature slab.
// (h_im, w_im) are pixel coordinates (may fall outside the image; taps that
// land out of bounds contribute 0).  base_ptr selects batch/scale/channel.
__device__ float bilinear_sampling(
    const float *&bottom_data, const int &height, const int &width,
    const int &num_embeds, const float &h_im, const float &w_im,
    const int &base_ptr
) {
    const int h_low = floorf(h_im);
    const int w_low = floorf(w_im);
    const int h_high = h_low + 1;
    const int w_high = w_low + 1;

    // Fractional offsets and their complements = bilinear weights.
    const float lh = h_im - h_low;
    const float lw = w_im - w_low;
    const float hh = 1 - lh, hw = 1 - lw;

    // Memory is channel-last: stride num_embeds along W, width*num_embeds
    // along H.
    const int w_stride = num_embeds;
    const int h_stride = width * w_stride;
    const int h_low_ptr_offset = h_low * h_stride;
    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
    const int w_low_ptr_offset = w_low * w_stride;
    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

    // Four taps, each guarded against the image border.
    float v1 = 0;
    if (h_low >= 0 && w_low >= 0) {
        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
        v1 = bottom_data[ptr1];
    }
    float v2 = 0;
    if (h_low >= 0 && w_high <= width - 1) {
        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
        v2 = bottom_data[ptr2];
    }
    float v3 = 0;
    if (h_high <= height - 1 && w_low >= 0) {
        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
        v3 = bottom_data[ptr3];
    }
    float v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1) {
        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
        v4 = bottom_data[ptr4];
    }

    const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

    const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    return val;
}


// Backward of bilinear_sampling for one (sample point, channel):
//  - accumulates d(loss)/d(feature) into grad_mc_ms_feat (atomicAdd),
//  - accumulates d(loss)/d(sampling location) into grad_sampling_location
//    ({w, h} order, scaled back to normalised coords by width/height),
//  - accumulates d(loss)/d(weight) into grad_weights.
__device__ void bilinear_sampling_grad(
    const float *&bottom_data, const float &weight,
    const int &height, const int &width,
    const int &num_embeds, const float &h_im, const float &w_im,
    const int &base_ptr,
    const float &grad_output,
    float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights) {
    const int h_low = floorf(h_im);
    const int w_low = floorf(w_im);
    const int h_high = h_low + 1;
    const int w_high = w_low + 1;

    const float lh = h_im - h_low;
    const float lw = w_im - w_low;
    const float hh = 1 - lh, hw = 1 - lw;

    const int w_stride = num_embeds;
    const int h_stride = width * w_stride;
    const int h_low_ptr_offset = h_low * h_stride;
    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
    const int w_low_ptr_offset = w_low * w_stride;
    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

    const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
    // Chain rule: gradient reaching the feature taps is scaled by the
    // aggregation weight.
    const float top_grad_mc_ms_feat = grad_output * weight;
    float grad_h_weight = 0, grad_w_weight = 0;

    float v1 = 0;
    if (h_low >= 0 && w_low >= 0) {
        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
        v1 = bottom_data[ptr1];
        grad_h_weight -= hw * v1;
        grad_w_weight -= hh * v1;
        atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
    }
    float v2 = 0;
    if (h_low >= 0 && w_high <= width - 1) {
        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
        v2 = bottom_data[ptr2];
        grad_h_weight -= lw * v2;
        grad_w_weight += hh * v2;
        atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
    }
    float v3 = 0;
    if (h_high <= height - 1 && w_low >= 0) {
        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
        v3 = bottom_data[ptr3];
        grad_h_weight += hw * v3;
        grad_w_weight -= lh * v3;
        atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
    }
    float v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1) {
        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
        v4 = bottom_data[ptr4];
        grad_h_weight += lw * v4;
        grad_w_weight += lh * v4;
        atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
    }

    // d(loss)/d(weight) is the sampled value times the upstream gradient.
    const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    atomicAdd(grad_weights, grad_output * val);
    // width/height factors convert pixel-space gradients back to the
    // normalised [0, 1] sampling-location space used by the caller.
    atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
    atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
}


// Forward kernel: one thread per (batch, anchor, point, camera, scale,
// channel).  Samples the multi-camera multi-scale feature map and
// accumulates weight * sampled value into output[b, anchor, channel].
__global__ void deformable_aggregation_kernel(
    const int num_kernels,
    float* output,
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_kernels) return;

    // Channels within the same group share one weight entry.
    const float weight = *(weights + idx / (num_embeds / num_groups));
    // Decompose the flat index; order matches the launch layout
    // [batch][anchor][pts][cam][scale][embed].
    const int channel_index = idx % num_embeds;
    idx /= num_embeds;
    const int scale_index = idx % num_scale;
    idx /= num_scale;

    const int cam_index = idx % num_cams;
    idx /= num_cams;
    const int pts_index = idx % num_pts;
    idx /= num_pts;

    int anchor_index = idx % num_anchors;
    idx /= num_anchors;
    const int batch_index = idx % batch_size;
    idx /= batch_size;

    anchor_index = batch_index * num_anchors + anchor_index;
    const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;

    // Sampling locations are normalised to (0, 1); anything on or outside
    // the border contributes nothing.
    const float loc_w = sample_location[loc_offset];
    if (loc_w <= 0 || loc_w >= 1) return;
    const float loc_h = sample_location[loc_offset + 1];
    if (loc_h <= 0 || loc_h >= 1) return;

    int cam_scale_index = cam_index * num_scale + scale_index;
    const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;

    cam_scale_index = cam_scale_index << 1;
    const int h = spatial_shape[cam_scale_index];
    const int w = spatial_shape[cam_scale_index + 1];

    // -0.5 aligns normalised coordinates with pixel centres.
    const float h_im = loc_h * h - 0.5;
    const float w_im = loc_w * w - 0.5;

    atomicAdd(
        output + anchor_index * num_embeds + channel_index,
        bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight
    );
}
// Backward kernel: same thread decomposition as the forward kernel; routes
// the upstream gradient through bilinear_sampling_grad into the feature,
// sampling-location and weight gradients.
__global__ void deformable_aggregation_grad_kernel(
    const int num_kernels,
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    const float* grad_output,
    float* grad_mc_ms_feat,
    float* grad_sampling_location,
    float* grad_weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_kernels) return;

    // Channels within the same group share one weight entry.
    const int weights_ptr = idx / (num_embeds / num_groups);
    const int channel_index = idx % num_embeds;
    idx /= num_embeds;
    const int scale_index = idx % num_scale;
    idx /= num_scale;

    const int cam_index = idx % num_cams;
    idx /= num_cams;
    const int pts_index = idx % num_pts;
    idx /= num_pts;

    int anchor_index = idx % num_anchors;
    idx /= num_anchors;
    const int batch_index = idx % batch_size;
    idx /= batch_size;

    anchor_index = batch_index * num_anchors + anchor_index;
    const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;

    // Out-of-range samples contributed nothing forward, so no gradient.
    const float loc_w = sample_location[loc_offset];
    if (loc_w <= 0 || loc_w >= 1) return;
    const float loc_h = sample_location[loc_offset + 1];
    if (loc_h <= 0 || loc_h >= 1) return;

    const float grad = grad_output[anchor_index * num_embeds + channel_index];

    int cam_scale_index = cam_index * num_scale + scale_index;
    const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;

    cam_scale_index = cam_scale_index << 1;
    const int h = spatial_shape[cam_scale_index];
    const int w = spatial_shape[cam_scale_index + 1];

    const float h_im = loc_h * h - 0.5;
    const float w_im = loc_w * w - 0.5;

    // (Dead commented-out copy of the forward atomicAdd removed.)
    const float weight = weights[weights_ptr];
    float *grad_weights_ptr = grad_weights + weights_ptr;
    float *grad_location_ptr = grad_sampling_location + loc_offset;
    bilinear_sampling_grad(
        mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
        value_offset,
        grad,
        grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr
    );
}


// Host launcher for the forward kernel: one thread per
// (batch, anchor, point, camera, scale, channel), 128 threads per block,
// launched on the current (default) stream.
void deformable_aggregation(
    float* output,
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
    const int num_kernels = batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
    deformable_aggregation_kernel
        <<<(int)ceil(((double)num_kernels / 128)), 128>>>(
            num_kernels, output,
            mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
            batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
        );
}


// Host launcher for the backward kernel; same grid layout as the forward.
void deformable_aggregation_grad(
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    const float* grad_output,
    float* grad_mc_ms_feat,
    float* grad_sampling_location,
    float* grad_weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
    const int num_kernels = batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
    deformable_aggregation_grad_kernel
        <<<(int)ceil(((double)num_kernels / 128)), 128>>>(
            num_kernels,
            mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
            grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
            batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
        );
}

// (next in the collapsed diff: new file requirement.txt)
# ===========================================================================
# Reconstruction of several small new files from the whitespace-collapsed
# diff (requirements, binary-resource notices, and the launcher scripts).
# ===========================================================================

# ---- file: requirement.txt (sic — upstream spelling; pip requirements) ----
numpy==1.23.5
mmcv_full==1.7.1
mmdet==2.28.2
urllib3==1.26.16
pyquaternion==0.9.9
nuscenes-devkit==1.1.10
yapf==0.32
tensorboard==2.14.0
motmetrics==1.1.3
pandas==1.1.5
flash-attn==2.3.2
opencv-python==4.8.1.78
prettytable==3.7.0
scikit-learn==1.3.0
wandb==0.16.6

# ---- binary files added by the diff (content not representable here) ------
#   resources/legend.png, resources/motion_planner.png,
#   resources/overview.png, resources/sdc_car.png,
#   resources/sparse_perception.png

# ---- file: scripts/create_data.sh -----------------------------------------
export PYTHONPATH="$(dirname $0)/..":$PYTHONPATH

python tools/data_converter/nuscenes_converter.py nuscenes \
    --root-path ./data/nuscenes \
    --canbus ./data/nuscenes \
    --out-dir ./data/infos/ \
    --extra-tag nuscenes \
    --version v1.0-mini

# NOTE(review): the converter only lists v1.0-trainval / v1.0-test /
# v1.0-mini as valid versions; confirm that its CLI maps "v1.0" onto one of
# those before relying on this second invocation.
python tools/data_converter/nuscenes_converter.py nuscenes \
    --root-path ./data/nuscenes \
    --canbus ./data/nuscenes \
    --out-dir ./data/infos/ \
    --extra-tag nuscenes \
    --version v1.0

# ---- file: scripts/kmeans.sh ----------------------------------------------
python tools/kmeans/kmeans_det.py
python tools/kmeans/kmeans_map.py
python tools/kmeans/kmeans_motion.py
python tools/kmeans/kmeans_plan.py

# ---- file: scripts/test.sh ------------------------------------------------
bash ./tools/dist_test.sh \
    projects/configs/sparsedrive_small_stage2.py \
    ckpt/sparsedrive_stage2.pth \
    8 \
    --deterministic \
    --eval bbox
    # --result_file ./work_dirs/sparsedrive_small_stage2/results.pkl

# ---- file: scripts/train.sh -----------------------------------------------
## stage1
bash ./tools/dist_train.sh \
    projects/configs/sparsedrive_small_stage1.py \
    8 \
    --deterministic

## stage2
bash ./tools/dist_train.sh \
    projects/configs/sparsedrive_small_stage2.py \
    8 \
    --deterministic

# ---- file: scripts/visualize.sh -------------------------------------------
export PYTHONPATH="$(dirname $0)/..":$PYTHONPATH
python tools/visualization/visualize.py \
    projects/configs/sparsedrive_small_stage2.py \
    --result-path work_dirs/sparsedrive_small_stage2/results.pkl

# ---- file: tools/analyze_occ_mask.py (head; shebang and module docstring
# are reconstructed in full in the following chunk) ------------------------
#   #!/usr/bin/env python3
#   """Compute statistics for NuScenes objects that don't meet the
#   num_lidar_pts > 0 mask criteria used in NuScenes3DDataset.get_ann_info()
#   ..."""
#!/usr/bin/env python3
"""Statistics for NuScenes objects dropped by the ``num_lidar_pts > 0`` mask.

Counts annotations that NuScenes3DDataset.get_ann_info() filters out
(``num_lidar_pts <= 0``), optionally per class, plus track-level stats on
annotations missing from the info files entirely.

Usage:
    python tools/analyze_occ_mask.py [ann_file]
    python tools/analyze_occ_mask.py data/infos/nuscenes_infos_train.pkl
    python tools/analyze_occ_mask.py data/infos/nuscenes_infos_val.pkl
"""
# Fix: the original Usage block referenced tools/analyze_num_lidar_pts_mask.py,
# which does not match this file's actual path.

import argparse
from collections import defaultdict

import mmcv
import numpy as np


def main():
    parser = argparse.ArgumentParser(
        description="Stats for objects with num_lidar_pts <= 0 (filtered out by dataset mask)"
    )
    parser.add_argument(
        "ann_file",
        nargs="?",
        default="data/infos/nuscenes_infos_train.pkl",
        help="Path to NuScenes info pkl (default: data/infos/nuscenes_infos_train.pkl)",
    )
    parser.add_argument(
        "--load-interval",
        type=int,
        default=1,
        help="Same as dataset load_interval (default: 1)",
    )
    parser.add_argument(
        "--by-class",
        action="store_true",
        help="Print per-class breakdown of filtered objects",
    )
    args = parser.parse_args()

    print(f"Loading {args.ann_file} ...")
    data = mmcv.load(args.ann_file, file_format="pkl")
    # Mirror the dataset's ordering and subsampling so counts match training.
    infos = list(sorted(data["infos"], key=lambda e: e["timestamp"]))
    infos = infos[:: args.load_interval]
    print(f"Loaded {len(infos)} samples (load_interval={args.load_interval})\n")

    total_objs = 0
    filtered_out = 0  # objects with num_lidar_pts <= 0
    samples_with_any_filtered = 0
    samples_with_all_filtered = 0
    # num_lidar_pts value distribution for filtered objects (0, -1, etc.)
    filtered_pts_dist = defaultdict(int)
    # per-class counters: total annotations and filtered annotations
    class_total = defaultdict(int)
    class_filtered = defaultdict(int)

    for info in infos:
        num_pts = np.asarray(info["num_lidar_pts"])
        gt_names = info["gt_names"]
        n = len(num_pts)
        if n == 0:
            continue

        total_objs += n
        mask_kept = num_pts > 0
        mask_filtered = ~mask_kept
        n_filtered = mask_filtered.sum()
        filtered_out += n_filtered

        if n_filtered > 0:
            samples_with_any_filtered += 1
            if n_filtered == n:
                samples_with_all_filtered += 1

        for v in num_pts[mask_filtered]:
            filtered_pts_dist[int(v)] += 1

        for i in range(n):
            # gt_names may hold numpy str scalars; normalise to python str.
            name = gt_names[i] if isinstance(gt_names[i], str) else gt_names[i].item()
            class_total[name] += 1
            if num_pts[i] <= 0:
                class_filtered[name] += 1

    # Summary
    print("=" * 60)
    print("Summary (mask: num_lidar_pts > 0)")
    print("=" * 60)
    print(f"  Total objects: {total_objs}")
    print(f"  Filtered (num_lidar_pts <= 0): {filtered_out}")
    # Fix: guard against annotation-free input (ZeroDivisionError).
    pct = 100.0 * filtered_out / total_objs if total_objs else 0.0
    print(f"  Filtered %: {pct:.2f}%")
    # Fix: these counters were collected but never reported.
    print(f"  Samples with any filtered object: {samples_with_any_filtered}")
    print(f"  Samples with all objects filtered: {samples_with_all_filtered}")
    if filtered_pts_dist:
        print("  num_lidar_pts values among filtered objects:")
        for v in sorted(filtered_pts_dist):
            print(f"    {v}: {filtered_pts_dist[v]}")

    if args.by_class and class_total:
        print("\n" + "=" * 60)
        print("Per-class (filtered = num_lidar_pts <= 0)")
        print("=" * 60)
        for name in sorted(class_total.keys()):
            tot = class_total[name]
            flt = class_filtered[name]
            pct = 100.0 * flt / tot if tot else 0
            print(f"  {name:25s} total: {tot:6d} filtered: {flt:6d} ({pct:5.2f}%)")

    # -------------------------------------------------------------------------
    # Track-level stats: per track, how many annotations are not present at all
    # in the infos (instance not labelled in that sample) within the track's
    # span, and how many scene samples follow the track's last annotation.
    # -------------------------------------------------------------------------
    # Fix: guard infos[0] against an empty info list (IndexError).
    if infos and "instance_inds" in infos[0]:
        # scene_token -> sorted list of sample_idx belonging to that scene
        scene_to_sample_indices = defaultdict(list)
        # scene_token -> instance_ind -> sample indices where it is annotated
        tracks = defaultdict(lambda: defaultdict(list))

        for sample_idx, info in enumerate(infos):
            scene_token = info["scene_token"]
            scene_to_sample_indices[scene_token].append(sample_idx)
            instance_inds = info.get("instance_inds", [])
            if len(instance_inds) == 0:
                continue
            instance_inds = np.asarray(instance_inds)
            for i in range(len(instance_inds)):
                instance_ind = int(instance_inds[i])
                tracks[scene_token][instance_ind].append(sample_idx)

        for scene_token in scene_to_sample_indices:
            scene_to_sample_indices[scene_token].sort()

        # For each track: span = [first_sample, last_sample].  Missing = the
        # number of scene samples in that span where the instance is absent.
        total_tracks = 0
        total_missing_not_labelled = 0
        total_missing_after_track_end = 0
        tracks_with_any_missing = 0
        missing_per_track_dist = defaultdict(int)
        after_end_per_track_dist = defaultdict(int)

        for scene_token, instances in tracks.items():
            scene_samples = scene_to_sample_indices[scene_token]
            for instance_ind, appearances in instances.items():
                if len(appearances) == 0:
                    continue
                appearances = sorted(appearances)
                first_sample = appearances[0]
                last_sample = appearances[-1]
                # Scene samples that fall inside the track span.
                span_samples = [s for s in scene_samples if first_sample <= s <= last_sample]
                expected_frames = len(span_samples)
                present_frames = len(appearances)
                missing = expected_frames - present_frames

                # Scene samples after the last labelled sample for this track.
                after_end = sum(1 for s in scene_samples if s > last_sample)
                total_missing_after_track_end += after_end
                after_end_per_track_dist[after_end] += 1

                total_tracks += 1
                total_missing_not_labelled += missing
                if missing > 0:
                    tracks_with_any_missing += 1
                missing_per_track_dist[missing] += 1

        print("\n" + "=" * 60)
        print("Track stats (missing = not present/labelled in infos at all)")
        print("=" * 60)
        print(f"  Total tracks: {total_tracks}")
        # Fix: tracks_with_any_missing was computed but never reported.
        print(f"  Tracks with any mid-span missing annotation: {tracks_with_any_missing}")
        print(f"  Total annotations missing mid track span (not in infos): {total_missing_not_labelled}")
        print(f"  Total annotations missing after track end (not in infos): {total_missing_after_track_end}")
        print("\n  Distribution of # missing per track mid span (not in infos):")
        for k in sorted(missing_per_track_dist.keys()):
            n_tracks = missing_per_track_dist[k]
            print(f"    {k:3d} missing: {n_tracks:6d} tracks")
        print("\n  Distribution of # missing after track end per track:")
        for k in sorted(after_end_per_track_dist.keys()):
            n_tracks = after_end_per_track_dist[k]
            print(f"    {k:3d} after end: {n_tracks:6d} tracks")
    else:
        print("\n  (Skipping track stats: no 'instance_inds' in infos)")

    print()


if __name__ == "__main__":
    main()
# ---- tail of tools/analyze_occ_mask.py (reconstructed in full in the
# previous chunk; kept as comments here to avoid duplication):
#     print()
#
# if __name__ == "__main__":
#     main()

# ===========================================================================
# file: tools/benchmark.py  (reconstructed from the collapsed diff)
# ===========================================================================
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import time
import torch
from mmcv import Config
from mmcv.parallel import MMDataParallel
from mmcv.runner import load_checkpoint, wrap_fp16_model
import sys
sys.path.append('.')
from projects.mmdet3d_plugin.datasets.builder import build_dataloader
from projects.mmdet3d_plugin.datasets import custom_build_dataset
from mmdet.models import build_detector
from mmcv.cnn.utils.flops_counter import add_flops_counting_methods
from mmcv.parallel import scatter


def parse_args():
    """Parse CLI arguments for the benchmark."""
    parser = argparse.ArgumentParser(description='MMDet benchmark a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    # Fix: --samples/--log-interval had no type=int, so values passed on the
    # CLI arrived as strings, breaking `(i + 1) % args.log_interval` and
    # making `(i + 1) == args.samples` never true.
    parser.add_argument(
        '--samples', type=int, default=1000, help='samples to benchmark')
    parser.add_argument(
        '--log-interval', type=int, default=50, help='interval of logging')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase'
        'the inference speed')
    args = parser.parse_args()
    return args


def get_max_memory(model):
    """Return the peak CUDA memory (MiB) allocated on the model's device."""
    device = getattr(model, 'output_device', None)
    mem = torch.cuda.max_memory_allocated(device=device)
    mem_mb = torch.tensor([mem / (1024 * 1024)],
                          dtype=torch.int,
                          device=device)
    return mem_mb.item()


def main():
    args = parse_args()
    get_flops_params(args)
    get_mem_fps(args)


def get_mem_fps(args):
    """Benchmark inference FPS and peak GPU memory on the test dataloader."""
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # build the dataloader (single image per GPU)
    print(cfg.data.test)
    dataset = custom_build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    if args.checkpoint is not None:
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    # NOTE(review): --fuse-conv-bn is accepted but fuse_module was left
    # commented out upstream; the flag is currently a no-op.

    model = MMDataParallel(model, device_ids=[0])

    model.eval()

    # the first several iterations may be very slow so skip them
    num_warmup = 5
    pure_inf_time = 0

    # benchmark with several samples and take the average
    max_memory = 0
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            start_time = time.perf_counter()
            model(return_loss=False, rescale=True, **data)

            torch.cuda.synchronize()
            elapsed = time.perf_counter() - start_time
            max_memory = max(max_memory, get_max_memory(model))

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % args.log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(f'Done image [{i + 1:<3}/ {args.samples}], '
                      f'fps: {fps:.1f} img / s, '
                      f"gpu mem: {max_memory} M")

        if (i + 1) == args.samples:
            # Fix: the original added `elapsed` again here, double-counting
            # the final sample's time in the overall FPS.
            fps = (i + 1 - num_warmup) / pure_inf_time
            print(f'Overall fps: {fps:.1f} img / s')
            break


def get_flops_params(args):
    """Report FLOPs/params for the whole model and its main sub-modules.

    The flops counter cannot see inside the custom deformable-aggregation
    CUDA op, so its cost is estimated analytically from the config and added
    to the head/total numbers.
    """
    gpu_id = 0
    cfg = Config.fromfile(args.config)
    dataset = custom_build_dataset(cfg.data.val)
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=0,
        dist=False,
        shuffle=False,
    )
    data_iter = dataloader.__iter__()
    data = next(data_iter)
    data = scatter(data, [gpu_id])[0]

    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    if args.checkpoint is not None:
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = model.cuda(gpu_id)
    model.eval()

    # ~11 FLOPs per bilinear sample (4 taps, weights, accumulation).
    bilinear_flops = 11
    num_key_pts_det = (
        cfg.model["head"]['det_head']["deformable_model"]["kps_generator"]["num_learnable_pts"]
        + len(cfg.model["head"]['det_head']["deformable_model"]["kps_generator"]["fix_scale"])
    )
    deformable_agg_flops_det = (
        cfg.num_decoder
        * cfg.embed_dims
        * cfg.num_levels
        * cfg.model["head"]['det_head']["instance_bank"]["num_anchor"]
        * cfg.model["head"]['det_head']["deformable_model"]["num_cams"]
        * num_key_pts_det
        * bilinear_flops
    )
    num_key_pts_map = (
        cfg.model["head"]['map_head']["deformable_model"]["kps_generator"]["num_learnable_pts"]
        + len(cfg.model["head"]['map_head']["deformable_model"]["kps_generator"]["fix_height"])
    ) * cfg.model["head"]['map_head']["deformable_model"]["kps_generator"]["num_sample"]
    deformable_agg_flops_map = (
        cfg.num_decoder
        * cfg.embed_dims
        * cfg.num_levels
        * cfg.model["head"]['map_head']["instance_bank"]["num_anchor"]
        * cfg.model["head"]['map_head']["deformable_model"]["num_cams"]
        * num_key_pts_map
        * bilinear_flops
    )
    deformable_agg_flops = deformable_agg_flops_det + deformable_agg_flops_map

    for module in ["total", "img_backbone", "img_neck", "head"]:
        if module != "total":
            flops_model = add_flops_counting_methods(getattr(model, module))
        else:
            flops_model = add_flops_counting_methods(model)
        flops_model.eval()
        flops_model.start_flops_count()

        # Feed each sub-module the input it would see inside the full model.
        if module == "img_backbone":
            flops_model(data["img"].flatten(0, 1))
        elif module == "img_neck":
            flops_model(model.img_backbone(data["img"].flatten(0, 1)))
        elif module == "head":
            flops_model(model.extract_feat(data["img"], metas=data), data)
        else:
            flops_model(**data)
        flops_count, params_count = flops_model.compute_average_flops_cost()
        flops_count *= flops_model.__batch_counter__
        flops_model.stop_flops_count()
        if module == "head" or module == "total":
            flops_count += deformable_agg_flops
        if module == "total":
            total_flops = flops_count
            total_params = params_count
        print(
            f"{module:<13} complexity: "
            f"FLOPs={flops_count/ 10.**9:>8.4f} G / {flops_count/total_flops*100:>6.2f}%, "
            f"Params={params_count/10**6:>8.4f} M / {params_count/total_params*100:>6.2f}%."
        )


if __name__ == '__main__':
    main()

# ===========================================================================
# file: tools/data_converter/__init__.py
#   # Copyright (c) OpenMMLab. All rights reserved.
# ===========================================================================
# file: tools/data_converter/nuscenes_converter.py  (head)
# ===========================================================================
import os
import math
import copy
import argparse
from os import path as osp
from collections import OrderedDict
from typing import List, Tuple, Union

import numpy as np
from pyquaternion import Quaternion
from shapely.geometry import MultiPoint, box

import mmcv

from nuscenes.nuscenes import NuScenes
from nuscenes.can_bus.can_bus_api import NuScenesCanBus
from nuscenes.utils.geometry_utils import transform_matrix
from nuscenes.utils.data_classes import Box
from nuscenes.utils.geometry_utils import view_points
from nuscenes.prediction import PredictHelper, convert_local_coords_to_global

from projects.mmdet3d_plugin.datasets.map_utils.nuscmap_extractor import NuscMapExtractor

# NameMapping head entries ("movable_object.barrier" .. "human.pedestrian.
# child"); the dict is defined in full in the following chunk.
"human.pedestrian.construction_worker": "pedestrian", + "human.pedestrian.police_officer": "pedestrian", + "movable_object.trafficcone": "traffic_cone", + "vehicle.trailer": "trailer", + "vehicle.truck": "truck", +} + +def quart_to_rpy(qua): + x, y, z, w = qua + roll = math.atan2(2 * (w * x + y * z), 1 - 2 * (x * x + y * y)) + pitch = math.asin(2 * (w * y - x * z)) + yaw = math.atan2(2 * (w * z + x * y), 1 - 2 * (z * z + y * y)) + return roll, pitch, yaw + +def locate_message(utimes, utime): + i = np.searchsorted(utimes, utime) + if i == len(utimes) or (i > 0 and utime - utimes[i-1] < utimes[i] - utime): + i -= 1 + return i + +def geom2anno(map_geoms): + MAP_CLASSES = ( + 'ped_crossing', + 'divider', + 'boundary', + ) + vectors = {} + for cls, geom_list in map_geoms.items(): + if cls in MAP_CLASSES: + label = MAP_CLASSES.index(cls) + vectors[label] = [] + for geom in geom_list: + line = np.array(geom.coords) + vectors[label].append(line) + return vectors + +def create_nuscenes_infos(root_path, + out_path, + can_bus_root_path, + info_prefix, + version='v1.0-trainval', + max_sweeps=10, + roi_size=(30, 60),): + """Create info file of nuscene dataset. + + Given the raw data, generate its related info file in pkl format. + + Args: + root_path (str): Path of the data root. + info_prefix (str): Prefix of the info file to be generated. + version (str): Version of the data. + Default: 'v1.0-trainval' + max_sweeps (int): Max number of sweeps. 
+ Default: 10 + """ + print(version, root_path) + nusc = NuScenes(version=version, dataroot=root_path, verbose=True) + nusc_map_extractor = NuscMapExtractor(root_path, roi_size) + nusc_can_bus = NuScenesCanBus(dataroot=can_bus_root_path) + from nuscenes.utils import splits + available_vers = ['v1.0-trainval', 'v1.0-test', 'v1.0-mini'] + assert version in available_vers + if version == 'v1.0-trainval': + train_scenes = splits.train + val_scenes = splits.val + elif version == 'v1.0-test': + train_scenes = splits.test + val_scenes = [] + elif version == 'v1.0-mini': + train_scenes = splits.mini_train + val_scenes = splits.mini_val + out_path = osp.join(out_path, 'mini') + else: + raise ValueError('unknown') + os.makedirs(out_path, exist_ok=True) + + # filter existing scenes. + available_scenes = get_available_scenes(nusc) + available_scene_names = [s['name'] for s in available_scenes] + train_scenes = list( + filter(lambda x: x in available_scene_names, train_scenes)) + val_scenes = list(filter(lambda x: x in available_scene_names, val_scenes)) + train_scenes = set([ + available_scenes[available_scene_names.index(s)]['token'] + for s in train_scenes + ]) + val_scenes = set([ + available_scenes[available_scene_names.index(s)]['token'] + for s in val_scenes + ]) + + test = 'test' in version + if test: + print('test scene: {}'.format(len(train_scenes))) + else: + print('train scene: {}, val scene: {}'.format( + len(train_scenes), len(val_scenes))) + + train_nusc_infos, val_nusc_infos = _fill_trainval_infos( + nusc, nusc_map_extractor, nusc_can_bus, train_scenes, val_scenes, test, max_sweeps=max_sweeps) + + metadata = dict(version=version) + if test: + print('test sample: {}'.format(len(train_nusc_infos))) + data = dict(infos=train_nusc_infos, metadata=metadata) + info_path = osp.join(out_path, + '{}_infos_test.pkl'.format(info_prefix)) + mmcv.dump(data, info_path) + else: + print('train sample: {}, val sample: {}'.format( + len(train_nusc_infos), len(val_nusc_infos))) + 
data = dict(infos=train_nusc_infos, metadata=metadata) + info_path = osp.join(out_path, + '{}_infos_train.pkl'.format(info_prefix)) + mmcv.dump(data, info_path) + data['infos'] = val_nusc_infos + info_val_path = osp.join(out_path, + '{}_infos_val.pkl'.format(info_prefix)) + mmcv.dump(data, info_val_path) + +def get_available_scenes(nusc): + """Get available scenes from the input nuscenes class. + + Given the raw data, get the information of available scenes for + further info generation. + + Args: + nusc (class): Dataset class in the nuScenes dataset. + + Returns: + available_scenes (list[dict]): List of basic information for the + available scenes. + """ + available_scenes = [] + print('total scene num: {}'.format(len(nusc.scene))) + for scene in nusc.scene: + scene_token = scene['token'] + scene_rec = nusc.get('scene', scene_token) + sample_rec = nusc.get('sample', scene_rec['first_sample_token']) + sd_rec = nusc.get('sample_data', sample_rec['data']['LIDAR_TOP']) + has_more_frames = True + scene_not_exist = False + while has_more_frames: + lidar_path, boxes, _ = nusc.get_sample_data(sd_rec['token']) + lidar_path = str(lidar_path) + if os.getcwd() in lidar_path: + # path from lyftdataset is absolute path + lidar_path = lidar_path.split(f'{os.getcwd()}/')[-1] + # relative path + if not mmcv.is_filepath(lidar_path): + scene_not_exist = True + break + else: + break + if scene_not_exist: + continue + available_scenes.append(scene) + print('exist scene num: {}'.format(len(available_scenes))) + return available_scenes + +def _fill_trainval_infos(nusc, + nusc_map_extractor, + nusc_can_bus, + train_scenes, + val_scenes, + test=False, + max_sweeps=10, + fut_ts=12, + ego_fut_ts=6): + """Generate the train/val infos from the raw data. + + Args: + nusc (:obj:`NuScenes`): Dataset class in the nuScenes dataset. + train_scenes (list[str]): Basic information of training scenes. + val_scenes (list[str]): Basic information of validation scenes. 
+ test (bool): Whether use the test mode. In the test mode, no + annotations can be accessed. Default: False. + max_sweeps (int): Max number of sweeps. Default: 10. + + Returns: + tuple[list[dict]]: Information of training set and validation set + that will be saved to the info file. + """ + train_nusc_infos = [] + val_nusc_infos = [] + cat2idx = {} + for idx, dic in enumerate(nusc.category): + cat2idx[dic['name']] = idx + + predict_helper = PredictHelper(nusc) + for sample in mmcv.track_iter_progress(nusc.sample): + map_location = nusc.get('log', nusc.get('scene', sample['scene_token'])['log_token'])['location'] + lidar_token = sample['data']['LIDAR_TOP'] + sd_rec = nusc.get('sample_data', lidar_token) + cs_record = nusc.get('calibrated_sensor', + sd_rec['calibrated_sensor_token']) + pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token']) + lidar_path, boxes, _ = nusc.get_sample_data(lidar_token) + mmcv.check_file_exist(lidar_path) + + info = { + 'lidar_path': lidar_path, + 'token': sample['token'], + 'sweeps': [], + 'cams': dict(), + 'scene_token': sample['scene_token'], + 'lidar2ego_translation': cs_record['translation'], + 'lidar2ego_rotation': cs_record['rotation'], + 'ego2global_translation': pose_record['translation'], + 'ego2global_rotation': pose_record['rotation'], + 'timestamp': sample['timestamp'], + 'map_location': map_location, + } + + l2e_r = info['lidar2ego_rotation'] + l2e_t = info['lidar2ego_translation'] + e2g_r = info['ego2global_rotation'] + e2g_t = info['ego2global_translation'] + l2e_r_mat = Quaternion(l2e_r).rotation_matrix + e2g_r_mat = Quaternion(e2g_r).rotation_matrix + + # extract map annos + lidar2ego = np.eye(4) + lidar2ego[:3, :3] = Quaternion( + info["lidar2ego_rotation"] + ).rotation_matrix + lidar2ego[:3, 3] = np.array(info["lidar2ego_translation"]) + ego2global = np.eye(4) + ego2global[:3, :3] = Quaternion( + info["ego2global_rotation"] + ).rotation_matrix + ego2global[:3, 3] = np.array(info["ego2global_translation"]) + 
lidar2global = ego2global @ lidar2ego + + translation = list(lidar2global[:3, 3]) + rotation = list(Quaternion(matrix=lidar2global).q) + map_geoms = nusc_map_extractor.get_map_geom(map_location, translation, rotation) + map_annos = geom2anno(map_geoms) + info['map_annos'] = map_annos + + # obtain 6 image's information per frame + camera_types = [ + 'CAM_FRONT', + 'CAM_FRONT_RIGHT', + 'CAM_FRONT_LEFT', + 'CAM_BACK', + 'CAM_BACK_LEFT', + 'CAM_BACK_RIGHT', + ] + for cam in camera_types: + cam_token = sample['data'][cam] + cam_path, _, cam_intrinsic = nusc.get_sample_data(cam_token) + cam_info = obtain_sensor2top(nusc, cam_token, l2e_t, l2e_r_mat, + e2g_t, e2g_r_mat, cam) + cam_info.update(cam_intrinsic=cam_intrinsic) + info['cams'].update({cam: cam_info}) + + # obtain sweeps for a single key-frame + sd_rec = nusc.get('sample_data', sample['data']['LIDAR_TOP']) + sweeps = [] + while len(sweeps) < max_sweeps: + if not sd_rec['prev'] == '': + sweep = obtain_sensor2top(nusc, sd_rec['prev'], l2e_t, + l2e_r_mat, e2g_t, e2g_r_mat, 'lidar') + sweeps.append(sweep) + sd_rec = nusc.get('sample_data', sd_rec['prev']) + else: + break + info['sweeps'] = sweeps + # obtain annotation + if not test: + # object detection annos: boxes (locs, dims, yaw, velocity), names and valid flags + annotations = [ + nusc.get('sample_annotation', token) + for token in sample['anns'] + ] + locs = np.array([b.center for b in boxes]).reshape(-1, 3) + dims = np.array([b.wlh for b in boxes]).reshape(-1, 3) + rots = np.array([b.orientation.yaw_pitch_roll[0] + for b in boxes]).reshape(-1, 1) + velocity = np.array( + [nusc.box_velocity(token)[:2] for token in sample['anns']]) + # convert velo from global to lidar + for i in range(len(boxes)): + velo = np.array([*velocity[i], 0.0]) + velo = velo @ np.linalg.inv(e2g_r_mat).T @ np.linalg.inv( + l2e_r_mat).T + velocity[i] = velo[:2] + names = [b.name for b in boxes] + for i in range(len(names)): + if names[i] in NameMapping: + names[i] = NameMapping[names[i]] + 
names = np.array(names) + valid_flag = np.array( + [(anno['num_lidar_pts'] + anno['num_radar_pts']) > 0 + for anno in annotations], + dtype=bool).reshape(-1) ## TODO update valid flag for tracking + # we need to convert box size to + # the format of our lidar coordinate system + # which is x_size, y_size, z_size (corresponding to l, w, h) + gt_boxes = np.concatenate([locs, dims[:, [1, 0, 2]], rots], axis=1) + assert len(gt_boxes) == len( + annotations), f'{len(gt_boxes)}, {len(annotations)}' + + # object tracking annos: instance_ids + instance_inds = [nusc.getind('instance', anno['instance_token']) + for anno in annotations] + + # motion prediction annos: future trajectories offset in lidar frame and valid mask + num_box = len(boxes) + gt_fut_trajs = np.zeros((num_box, fut_ts, 2)) + gt_fut_masks = np.zeros((num_box, fut_ts)) + for i, anno in enumerate(annotations): + instance_token = anno['instance_token'] + fut_traj_local = predict_helper.get_future_for_agent( + instance_token, + sample['token'], + seconds=fut_ts/2, + in_agent_frame=True + ) + if fut_traj_local.shape[0] > 0: + box = boxes[i] + trans = box.center + rot = Quaternion(matrix=box.rotation_matrix) + fut_traj_scene = convert_local_coords_to_global(fut_traj_local, trans, rot) + valid_step = fut_traj_scene.shape[0] + gt_fut_trajs[i, 0] = fut_traj_scene[0] - box.center[:2] + gt_fut_trajs[i, 1:valid_step] = fut_traj_scene[1:] - fut_traj_scene[:-1] + gt_fut_masks[i, :valid_step] = 1 + + # motion planning annos: future trajectories offset in lidar frame and valid mask + ego_fut_trajs = np.zeros((ego_fut_ts + 1, 3)) + ego_fut_masks = np.zeros((ego_fut_ts + 1)) + sample_cur = sample + ego_status = get_ego_status(nusc, nusc_can_bus, sample_cur) + for i in range(ego_fut_ts + 1): + pose_mat = get_global_sensor_pose(sample_cur, nusc) + ego_fut_trajs[i] = pose_mat[:3, 3] + ego_fut_masks[i] = 1 + if sample_cur['next'] == '': + ego_fut_trajs[i+1:] = ego_fut_trajs[i] + break + else: + sample_cur = nusc.get('sample', 
sample_cur['next']) + # global to ego + ego_fut_trajs = ego_fut_trajs - np.array(pose_record['translation']) + rot_mat = Quaternion(pose_record['rotation']).inverse.rotation_matrix + ego_fut_trajs = np.dot(rot_mat, ego_fut_trajs.T).T + # ego to lidar + ego_fut_trajs = ego_fut_trajs - np.array(cs_record['translation']) + rot_mat = Quaternion(cs_record['rotation']).inverse.rotation_matrix + ego_fut_trajs = np.dot(rot_mat, ego_fut_trajs.T).T + # drive command according to final fut step offset + if ego_fut_trajs[-1][0] >= 2: + command = np.array([1, 0, 0]) # Turn Right + elif ego_fut_trajs[-1][0] <= -2: + command = np.array([0, 1, 0]) # Turn Left + else: + command = np.array([0, 0, 1]) # Go Straight + # get offset + ego_fut_trajs = ego_fut_trajs[1:] - ego_fut_trajs[:-1] + + info['gt_boxes'] = gt_boxes + info['gt_names'] = names + info['gt_velocity'] = velocity.reshape(-1, 2) + info['num_lidar_pts'] = np.array( + [a['num_lidar_pts'] for a in annotations]) + info['num_radar_pts'] = np.array( + [a['num_radar_pts'] for a in annotations]) + info['valid_flag'] = valid_flag + info['instance_inds'] = instance_inds + info['gt_agent_fut_trajs'] = gt_fut_trajs.astype(np.float32) + info['gt_agent_fut_masks'] = gt_fut_masks.astype(np.float32) + info['gt_ego_fut_trajs'] = ego_fut_trajs[:, :2].astype(np.float32) + info['gt_ego_fut_masks'] = ego_fut_masks[1:].astype(np.float32) + info['gt_ego_fut_cmd'] = command.astype(np.float32) + info['ego_status'] = ego_status + + if sample['scene_token'] in train_scenes: + train_nusc_infos.append(info) + else: + val_nusc_infos.append(info) + + return train_nusc_infos, val_nusc_infos + +def get_ego_status(nusc, nusc_can_bus, sample): + ego_status = [] + ref_scene = nusc.get("scene", sample['scene_token']) + try: + pose_msgs = nusc_can_bus.get_messages(ref_scene['name'],'pose') + steer_msgs = nusc_can_bus.get_messages(ref_scene['name'], 'steeranglefeedback') + pose_uts = [msg['utime'] for msg in pose_msgs] + steer_uts = [msg['utime'] for msg in 
steer_msgs] + ref_utime = sample['timestamp'] + pose_index = locate_message(pose_uts, ref_utime) + pose_data = pose_msgs[pose_index] + steer_index = locate_message(steer_uts, ref_utime) + steer_data = steer_msgs[steer_index] + ego_status.extend(pose_data["accel"]) # acceleration in ego vehicle frame, m/s/s + ego_status.extend(pose_data["rotation_rate"]) # angular velocity in ego vehicle frame, rad/s + ego_status.extend(pose_data["vel"]) # velocity in ego vehicle frame, m/s + ego_status.append(steer_data["value"]) # steering angle, positive: left turn, negative: right turn + except: + ego_status = [0] * 10 + + return np.array(ego_status).astype(np.float32) + +def get_global_sensor_pose(rec, nusc): + lidar_sample_data = nusc.get('sample_data', rec['data']['LIDAR_TOP']) + + pose_record = nusc.get("ego_pose", lidar_sample_data["ego_pose_token"]) + cs_record = nusc.get("calibrated_sensor", lidar_sample_data["calibrated_sensor_token"]) + + ego2global = transform_matrix(pose_record["translation"], Quaternion(pose_record["rotation"]), inverse=False) + sensor2ego = transform_matrix(cs_record["translation"], Quaternion(cs_record["rotation"]), inverse=False) + pose = ego2global.dot(sensor2ego) + + return pose + +def obtain_sensor2top(nusc, + sensor_token, + l2e_t, + l2e_r_mat, + e2g_t, + e2g_r_mat, + sensor_type='lidar'): + """Obtain the info with RT matric from general sensor to Top LiDAR. + + Args: + nusc (class): Dataset class in the nuScenes dataset. + sensor_token (str): Sample data token corresponding to the + specific sensor type. + l2e_t (np.ndarray): Translation from lidar to ego in shape (1, 3). + l2e_r_mat (np.ndarray): Rotation matrix from lidar to ego + in shape (3, 3). + e2g_t (np.ndarray): Translation from ego to global in shape (1, 3). + e2g_r_mat (np.ndarray): Rotation matrix from ego to global + in shape (3, 3). + sensor_type (str): Sensor to calibrate. Default: 'lidar'. + + Returns: + sweep (dict): Sweep information after transformation. 
+ """ + sd_rec = nusc.get('sample_data', sensor_token) + cs_record = nusc.get('calibrated_sensor', + sd_rec['calibrated_sensor_token']) + pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token']) + data_path = str(nusc.get_sample_data_path(sd_rec['token'])) + if os.getcwd() in data_path: # path from lyftdataset is absolute path + data_path = data_path.split(f'{os.getcwd()}/')[-1] # relative path + sweep = { + 'data_path': data_path, + 'type': sensor_type, + 'sample_data_token': sd_rec['token'], + 'sensor2ego_translation': cs_record['translation'], + 'sensor2ego_rotation': cs_record['rotation'], + 'ego2global_translation': pose_record['translation'], + 'ego2global_rotation': pose_record['rotation'], + 'timestamp': sd_rec['timestamp'] + } + + l2e_r_s = sweep['sensor2ego_rotation'] + l2e_t_s = sweep['sensor2ego_translation'] + e2g_r_s = sweep['ego2global_rotation'] + e2g_t_s = sweep['ego2global_translation'] + + # obtain the RT from sensor to Top LiDAR + # sweep->ego->global->ego'->lidar + l2e_r_s_mat = Quaternion(l2e_r_s).rotation_matrix + e2g_r_s_mat = Quaternion(e2g_r_s).rotation_matrix + R = (l2e_r_s_mat.T @ e2g_r_s_mat.T) @ ( + np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T) + T = (l2e_t_s @ e2g_r_s_mat.T + e2g_t_s) @ ( + np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T) + T -= e2g_t @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T + ) + l2e_t @ np.linalg.inv(l2e_r_mat).T + sweep['sensor2lidar_rotation'] = R.T # points @ R.T + T + sweep['sensor2lidar_translation'] = T + return sweep + +def nuscenes_data_prep(root_path, + can_bus_root_path, + info_prefix, + version, + dataset_name, + out_dir, + max_sweeps=10): + """Prepare data related to nuScenes dataset. + + Related data consists of '.pkl' files recording basic infos, + 2D annotations and groundtruth database. + + Args: + root_path (str): Path of dataset root. + info_prefix (str): The prefix of info filenames. + version (str): Dataset version. 
+ dataset_name (str): The dataset class name. + out_dir (str): Output directory of the groundtruth database info. + max_sweeps (int): Number of input consecutive frames. Default: 10 + """ + create_nuscenes_infos( + root_path, out_dir, can_bus_root_path, info_prefix, version=version, max_sweeps=max_sweeps) + + +parser = argparse.ArgumentParser(description='Data converter arg parser') +parser.add_argument('dataset', metavar='kitti', help='name of the dataset') +parser.add_argument( + '--root-path', + type=str, + default='./data/kitti', + help='specify the root path of dataset') +parser.add_argument( + '--canbus', + type=str, + default='./data', + help='specify the root path of nuScenes canbus') +parser.add_argument( + '--version', + type=str, + default='v1.0', + required=False, + help='specify the dataset version, no need for kitti') +parser.add_argument( + '--max-sweeps', + type=int, + default=10, + required=False, + help='specify sweeps of lidar per example') +parser.add_argument( + '--out-dir', + type=str, + default='./data/kitti', + required='False', + help='name of info pkl') +parser.add_argument('--extra-tag', type=str, default='kitti') +parser.add_argument( + '--workers', type=int, default=4, help='number of threads to be used') +args = parser.parse_args() + +if __name__ == '__main__': + if args.dataset == 'nuscenes' and args.version != 'v1.0-mini': + train_version = f'{args.version}-trainval' + nuscenes_data_prep( + root_path=args.root_path, + can_bus_root_path=args.canbus, + info_prefix=args.extra_tag, + version=train_version, + dataset_name='NuScenesDataset', + out_dir=args.out_dir, + max_sweeps=args.max_sweeps) + test_version = f'{args.version}-test' + nuscenes_data_prep( + root_path=args.root_path, + can_bus_root_path=args.canbus, + info_prefix=args.extra_tag, + version=test_version, + dataset_name='NuScenesDataset', + out_dir=args.out_dir, + max_sweeps=args.max_sweeps) + elif args.dataset == 'nuscenes' and args.version == 'v1.0-mini': + train_version = 
f'{args.version}' + nuscenes_data_prep( + root_path=args.root_path, + can_bus_root_path=args.canbus, + info_prefix=args.extra_tag, + version=train_version, + dataset_name='NuScenesDataset', + out_dir=args.out_dir, + max_sweeps=args.max_sweeps) diff --git a/tools/data_converter/nuscenes_occlusion_converter.py b/tools/data_converter/nuscenes_occlusion_converter.py new file mode 100644 index 0000000..9bbc4b4 --- /dev/null +++ b/tools/data_converter/nuscenes_occlusion_converter.py @@ -0,0 +1,1777 @@ +#!/usr/bin/env python3 +"""Fill annotation gaps in nuScenes info pkls and visualise the results. + +Two processing passes: + +1. **Gap interpolation** (CATR): fills frames between two consecutive + observations of the same instance using Constant Acceleration and Turn + Rate kinematics with a linear endpoint position correction. + +2. **Forward extrapolation** (CV): projects each track's last observed + state forward by up to ``--max-extrap-frames`` frames using Constant + Velocity. + +New flags in every sample info dict: + ``is_interpolated`` (bool array) — CATR gap-fill boxes + ``is_extrapolated`` (bool array) — CV tail-extrapolation boxes +Directly observed boxes have both flags False. + +Sub-commands +------------ +convert Run the annotation pipeline and save a new pkl. +visualize Draw BEV track plots for example scenes from an existing pkl. 
+ +Examples +-------- + python tools/data_converter/nuscenes_occlusion_converter.py convert \\ + --input data/infos/nuscenes_infos_val.pkl \\ + --output data/infos/nuscenes_infos_val_occ.pkl + + python tools/data_converter/nuscenes_occlusion_converter.py visualize \\ + --pkl data/infos/nuscenes_infos_val_occ.pkl \\ + --num-scenes 6 --output-dir viz/ +""" + +import pickle +import argparse +import os +import numpy as np +from collections import defaultdict + + +# =========================================================================== +# CATR kinematics helpers +# =========================================================================== + +def _normalize_angle(angle): + return (angle + np.pi) % (2.0 * np.pi) - np.pi + + +def _integrate_catr(x0, y0, yaw0, v0, omega, a, t, clamp_v=False): + """Vectorised Riemann-sum integration of CATR kinematics from 0 to t. + + clamp_v : if True, speed is clamped to [0, inf) so a decelerating object + stops rather than reversing (use for forward extrapolation). + """ + n = max(200, int(abs(t) * 400)) + dt = t / n + s = np.arange(n) * dt + theta = yaw0 + omega * s + v = v0 + a * s + if clamp_v: + v = np.maximum(v, 0.0) + x = x0 + float(np.sum(v * np.cos(theta))) * dt + y = y0 + float(np.sum(v * np.sin(theta))) * dt + return x, y + + +def catr_interpolate(state0, state1, t, T): + """Interpolate between two kinematic states using the CATR model. + + Derives constant turn rate (omega = dtheta/T) and linear acceleration + (a = dv/T) from the endpoint states, integrates to get intermediate + positions, then applies a linear endpoint correction so the trajectory + passes through both endpoints exactly. 
+ """ + x0, y0, z0, l, w, h, yaw0, vx0, vy0 = state0 + x1, y1, z1, _, _, _, yaw1, vx1, vy1 = state1 + alpha = t / T + v0 = 0.0 if np.isnan(vx0) or np.isnan(vy0) else float(np.hypot(vx0, vy0)) + v1 = 0.0 if np.isnan(vx1) or np.isnan(vy1) else float(np.hypot(vx1, vy1)) + omega = _normalize_angle(yaw1 - yaw0) / T + a = (v1 - v0) / T + x_t, y_t = _integrate_catr(x0, y0, yaw0, v0, omega, a, t) + x_T, y_T = _integrate_catr(x0, y0, yaw0, v0, omega, a, T) + x_interp = x_t + alpha * (x1 - x_T) + y_interp = y_t + alpha * (y1 - y_T) + z_interp = z0 + alpha * (z1 - z0) + yaw_t = _normalize_angle(yaw0 + omega * t) + v_t = max(0.0, v0 + a * t) + return (float(x_interp), float(y_interp), float(z_interp), + float(l), float(w), float(h), float(yaw_t), + float(v_t * np.cos(yaw_t)), float(v_t * np.sin(yaw_t))) + + +# =========================================================================== +# Coordinate transform helpers +# =========================================================================== + +def _lidar2global_RT(info): + """Return (R 3×3, t 3) for the lidar → global transform of this sample.""" + from pyquaternion import Quaternion + R_l2e = Quaternion(info['lidar2ego_rotation']).rotation_matrix + t_l2e = np.array(info['lidar2ego_translation']) + R_e2g = Quaternion(info['ego2global_rotation']).rotation_matrix + t_e2g = np.array(info['ego2global_translation']) + R = R_e2g @ R_l2e + t = R_e2g @ t_l2e + t_e2g + return R, t + + +def _box_to_global(box7, vel2, info): + """Convert a lidar-frame box + lidar-frame velocity to a global-frame state. 
+ + Parameters + ---------- + box7 : array-like (7,) [x, y, z, l, w, h, yaw] in lidar frame + vel2 : array-like (2,) [vx, vy] in lidar frame + (the nuScenes converter rotates box_velocity into lidar) + info : sample info dict + + Returns + ------- + 9-tuple (x_g, y_g, z_g, l, w, h, yaw_g, vx_g, vy_g) all in global frame + """ + R, t = _lidar2global_RT(info) + p_g = R @ np.array([box7[0], box7[1], box7[2]]) + t + yaw_offset = np.arctan2(R[1, 0], R[0, 0]) + yaw_g = _normalize_angle(float(box7[6]) + yaw_offset) + # Rotate velocity from lidar frame to global frame (translation has no + # effect on vectors, only the rotation part of R matters). + if np.isnan(vel2[0]) or np.isnan(vel2[1]): + vx_g, vy_g = 0.0, 0.0 + else: + v_g = R @ np.array([float(vel2[0]), float(vel2[1]), 0.0]) + vx_g, vy_g = float(v_g[0]), float(v_g[1]) + return (float(p_g[0]), float(p_g[1]), float(p_g[2]), + float(box7[3]), float(box7[4]), float(box7[5]), + yaw_g, vx_g, vy_g) + + +def _state_global_to_lidar(state_g, info): + """Convert a global-frame state tuple to the lidar frame of *info*. + + Both position/yaw and velocity are rotated into the target lidar frame, + matching the gt_velocity convention used by the nuScenes converter. + + Parameters + ---------- + state_g : 9-tuple (x_g, y_g, z_g, l, w, h, yaw_g, vx_g, vy_g) + info : target sample info dict + + Returns + ------- + 9-tuple (x, y, z, l, w, h, yaw, vx, vy) — all in target lidar frame. + """ + x_g, y_g, z_g, l, w, h, yaw_g, vx_g, vy_g = state_g + R, t = _lidar2global_RT(info) + p_l = R.T @ (np.array([x_g, y_g, z_g]) - t) + yaw_offset = np.arctan2(R[1, 0], R[0, 0]) + yaw_l = _normalize_angle(yaw_g - yaw_offset) + # Rotate velocity from global frame back into the target lidar frame. 
+ v_l = R.T @ np.array([vx_g, vy_g, 0.0]) + return (float(p_l[0]), float(p_l[1]), float(p_l[2]), + float(l), float(w), float(h), float(yaw_l), float(v_l[0]), float(v_l[1])) + + +def _global_xy_to_lidar_xy(positions_g, info): + """Transform (N,2) global XY positions into the lidar frame of *info*. + + Uses the full 3-D lidar→global rotation (R.T) applied to the XY plane; + the Z component is ignored since we only need XY deltas. + """ + R, t = _lidar2global_RT(info) + pos3 = np.zeros((len(positions_g), 3)) + pos3[:, :2] = positions_g + return ((pos3 - t) @ R)[:, :2] # R.T @ (p - t) written as row-vec form + + +# =========================================================================== +# ML-prediction helpers +# =========================================================================== + +def _load_ml_predictions(pred_path): + """Load a UniTraj inference NPZ and build an instance-token index. + + Returns + ------- + preds : ndarray (N, 6, 60, 2) agent-centric absolute XY offsets from + the agent's last-observed position, in a + frame whose +X axis aligns with the agent's + heading. + worlds : ndarray (N, 10) center_objects_world — [:2] global XY + origin, [6] global heading (yaw). + lookup : dict str → list[int] instance token → row indices in preds. + """ + npz = np.load(pred_path, allow_pickle=True) + preds = npz['predictions'] # (N, 6, 60, 2) + meta = npz['metadata'].item() # flat dict + worlds = meta['center_objects_world'].astype(np.float64) # (N, 10) + ids = meta['center_objects_id'] # (N,) str + lookup = defaultdict(list) + for i, inst in enumerate(ids): + lookup[str(inst)].append(i) + return preds, worlds, lookup + + +def _find_pred_index(inst_id, x_g, y_g, worlds, lookup, pos_tol=2.0): + """Return the row index in *worlds*/*preds* for this instance near (x_g, y_g). + + Matches by instance token first, then picks the candidate whose stored + global XY is within *pos_tol* metres. Returns None if no match. 
+ """ + cands = lookup.get(str(inst_id), []) + if not cands: + return None + dists = [np.hypot(worlds[ci, 0] - x_g, worlds[ci, 1] - y_g) for ci in cands] + best = int(np.argmin(dists)) + return cands[best] if dists[best] < pos_tol else None + + +def _pred_global_xy(ml_preds_i, world, step): + """Convert agent-centric prediction at *step* to global (x, y). + + ml_preds_i : (6, 60, 2) — mode 0 is used (probabilities not saved yet). + world : (10,) — world[:2] = global origin, world[6] = global heading. + step : 0-based UniTraj step (0 → 0.1 s ahead of prediction moment). + Returns (x_g, y_g) or None when step is out of [0, 59]. + """ + if not (0 <= step < ml_preds_i.shape[1]): + return None + pred_ac = ml_preds_i[0, step] # mode 0, shape (2,) + theta = float(world[6]) + c, s = np.cos(theta), np.sin(theta) + return (float(c * pred_ac[0] - s * pred_ac[1] + world[0]), + float(s * pred_ac[0] + c * pred_ac[1] + world[1])) + + +def _dt_to_step(dt_seconds, unitraj_dt): + """Seconds offset from prediction moment → 0-based UniTraj step index. + + Step 0 corresponds to *unitraj_dt* seconds ahead. + Use unitraj_dt=0.1 for 10 Hz models, 0.5 for 2 Hz models. + """ + return int(round(dt_seconds / unitraj_dt)) - 1 + + +def _build_instance_token_map(nuscenes_dataroot, version='v1.0-trainval'): + """Return a dict mapping integer nuScenes instance indices → token strings. + + ``nusc.getind('instance', token)`` returns the 0-based position of a record + in the instance table, which is the same as its index in instance.json. + This map lets us convert the integer ``inst_ind`` stored by ForeSight back + to the 32-char token string stored in the UniTraj NPZ. 
+ """ + import json + path = os.path.join(nuscenes_dataroot, version, 'instance.json') + with open(path) as fh: + instances = json.load(fh) + return {i: rec['token'] for i, rec in enumerate(instances)} + + +# =========================================================================== +# Motion-model fitting helper +# =========================================================================== + +def _fit_motion_model(appearances, infos, min_history, + ca_noise_thr, ca_max_thr, ca_consistency_thr, + omega_noise_thr, omega_max_thr, omega_consistency_thr): + """Fit scalar acceleration and turn-rate from the trailing consecutive + observation run of a track. + + Only frames that are immediately consecutive in the scene (frame_pos + difference == 1) are used, starting from the last observation and + working backwards. Heading is estimated from velocity direction when + speed > 0.5 m/s (more reliable than box yaw for moving objects). + + Returns + ------- + (a_fit, omega_fit) : + a_fit — scalar acceleration (m/s²); 0.0 when not reliably fitted. + omega_fit — turn rate (rad/s); 0.0 when not reliably fitted. + """ + # Build the longest trailing run of consecutive scene frames. + run = [appearances[-1]] + for i in range(len(appearances) - 2, -1, -1): + if appearances[i + 1][0] - appearances[i][0] == 1: + run.insert(0, appearances[i]) + else: + break + if len(run) < min_history: + return 0.0, 0.0 + + # Collect (timestamp, speed, box yaw) per frame. + # Box yaw is directly annotated; velocity is inferred from position differences + # and is therefore noisier — use box yaw for turn-rate estimation. + states = [] + for _, gi, bi in run: + info = infos[gi] + sg = _box_to_global(info['gt_boxes'][bi], info['gt_velocity'][bi], info) + spd = float(np.hypot(sg[7], sg[8])) + states.append((info['timestamp'] * 1e-6, spd, float(sg[6]))) + + # Per-interval acceleration and turn-rate. 
+ accels, omegas = [], [] + for i in range(len(states) - 1): + t0, v0, h0 = states[i] + t1, v1, h1 = states[i + 1] + dt = t1 - t0 + if dt <= 0: + continue + accels.append((v1 - v0) / dt) + omegas.append(_normalize_angle(h1 - h0) / dt) + + if not accels: + return 0.0, 0.0 + + a_mean = float(np.mean(accels)) + a_std = float(np.std(accels)) if len(accels) > 1 else 0.0 + om_mean = float(np.mean(omegas)) + om_std = float(np.std(omegas)) if len(omegas) > 1 else 0.0 + + a_fit = (a_mean + if a_std < ca_consistency_thr + and ca_noise_thr <= abs(a_mean) <= ca_max_thr + else 0.0) + omega_fit = (om_mean + if om_std < omega_consistency_thr + and omega_noise_thr <= abs(om_mean) <= omega_max_thr + else 0.0) + return a_fit, omega_fit + + +# =========================================================================== +# Shared annotation-append helper +# =========================================================================== + +def _append_annotation(info, x, y, z, l, w, h, yaw, vx, vy, + inst_ind, class_name, is_interp, is_extrap, + fut_trajs=None, fut_masks=None): + new_box = np.array([[x, y, z, l, w, h, yaw]], dtype=np.float32) + new_vel = np.array([[vx, vy]], dtype=np.float32) + fut_ts = info['gt_agent_fut_trajs'].shape[1] + if fut_trajs is None: + fut_trajs = np.zeros((1, fut_ts, 2), dtype=np.float32) + else: + fut_trajs = np.asarray(fut_trajs, dtype=np.float32).reshape(1, fut_ts, 2) + if fut_masks is None: + fut_masks = np.zeros((1, fut_ts), dtype=np.float32) + else: + fut_masks = np.asarray(fut_masks, dtype=np.float32).reshape(1, fut_ts) + info['gt_boxes'] = np.concatenate([info['gt_boxes'], new_box], axis=0) + info['gt_velocity'] = np.concatenate([info['gt_velocity'], new_vel], axis=0) + info['gt_agent_fut_trajs'] = np.concatenate([info['gt_agent_fut_trajs'], fut_trajs], axis=0) + info['gt_agent_fut_masks'] = np.concatenate([info['gt_agent_fut_masks'], fut_masks], axis=0) + info['gt_names'] = np.append(info['gt_names'], class_name) + info['valid_flag'] = 
np.append(info['valid_flag'], False) + info['num_lidar_pts'] = np.append(info['num_lidar_pts'], 0) + info['num_radar_pts'] = np.append(info['num_radar_pts'], 0) + info['is_interpolated'] = np.append(info['is_interpolated'], is_interp) + info['is_extrapolated'] = np.append(info['is_extrapolated'], is_extrap) + info['instance_inds'].append(inst_ind) + + +# =========================================================================== +# Convert sub-command +# =========================================================================== + +def cmd_convert(args): + print(f'Loading {args.input} ...') + with open(args.input, 'rb') as f: + data = pickle.load(f) + infos = data['infos'] + print(f' {len(infos)} samples loaded.') + + for info in infos: + n = len(info['gt_boxes']) + info['is_interpolated'] = np.zeros(n, dtype=bool) + info['is_extrapolated'] = np.zeros(n, dtype=bool) + + # Load ML predictions if provided + ml_preds = ml_worlds = ml_lookup = None + inst_token_map = {} # int index → 32-char token string + if getattr(args, 'predictions', None): + print(f'Loading ML predictions from {args.predictions} ...') + ml_preds, ml_worlds, ml_lookup = _load_ml_predictions(args.predictions) + print(f' {len(ml_preds)} prediction entries loaded.') + dataroot = getattr(args, 'nuscenes_dataroot', None) + if dataroot: + print(f'Building instance-token map from {dataroot} ...') + inst_token_map = _build_instance_token_map(dataroot) + print(f' {len(inst_token_map)} instance records indexed.') + else: + print('WARNING: --nuscenes-dataroot not set; ML predictions will ' + 'not be matched (all extrapolations fall back to CV).') + + scene_to_indices = defaultdict(list) + for gi, info in enumerate(infos): + scene_to_indices[info['scene_token']].append(gi) + for sc in scene_to_indices: + scene_to_indices[sc].sort(key=lambda i: infos[i]['timestamp']) + + # ------------------------------------------------------------------ + # Pass 1: gap interpolation (CATR) + # 
------------------------------------------------------------------ + total_interpolated = 0 + for _, global_indices in scene_to_indices.items(): + tracks = defaultdict(list) + for frame_pos, gi in enumerate(global_indices): + for bi, inst_ind in enumerate(infos[gi]['instance_inds']): + tracks[inst_ind].append((frame_pos, gi, bi)) + + for inst_ind, appearances in tracks.items(): + for k in range(len(appearances) - 1): + fp0, gi0, bi0 = appearances[k] + fp1, gi1, bi1 = appearances[k + 1] + if fp1 - fp0 <= 1: + continue + info0, info1 = infos[gi0], infos[gi1] + # Convert both endpoints to global frame so CATR operates + # in a single consistent coordinate system. + state0_g = _box_to_global( + info0['gt_boxes'][bi0], info0['gt_velocity'][bi0], info0) + state1_g = _box_to_global( + info1['gt_boxes'][bi1], info1['gt_velocity'][bi1], info1) + class_name = info0['gt_names'][bi0] + t0_s = info0['timestamp'] * 1e-6 + T = info1['timestamp'] * 1e-6 - t0_s + for gap_fp in range(fp0 + 1, fp1): + gap_gi = global_indices[gap_fp] + gap_info = infos[gap_gi] + t = gap_info['timestamp'] * 1e-6 - t0_s + result_g = catr_interpolate(state0_g, state1_g, t, T) + # Convert the global-frame result back to the target frame's + # lidar coordinates before storing. + x, y, z, l, w, h, yaw, vx, vy = _state_global_to_lidar( + result_g, gap_info) + # Build future trajectory: CATR positions for future scene + # frames that still fall within the known gap [t0_s, t0_s+T]. 
+ fut_ts_n = gap_info['gt_agent_fut_trajs'].shape[1] + fut_positions_g = [np.array([result_g[0], result_g[1]])] + for k in range(1, fut_ts_n + 1): + fut_fp = gap_fp + k + if fut_fp >= len(global_indices): + break + t_fut_rel = infos[global_indices[fut_fp]]['timestamp'] * 1e-6 - t0_s + if t_fut_rel > T: + break + fut_g = catr_interpolate(state0_g, state1_g, t_fut_rel, T) + fut_positions_g.append(np.array([fut_g[0], fut_g[1]])) + all_l = _global_xy_to_lidar_xy(np.array(fut_positions_g), gap_info) + fut_trajs = np.zeros((fut_ts_n, 2), dtype=np.float32) + fut_masks = np.zeros(fut_ts_n, dtype=np.float32) + for k in range(1, len(fut_positions_g)): + fut_trajs[k - 1] = all_l[k] - all_l[k - 1] + fut_masks[k - 1] = 1.0 + _append_annotation(gap_info, x, y, z, l, w, h, yaw, vx, vy, + inst_ind, class_name, + is_interp=True, is_extrap=False, + fut_trajs=fut_trajs, fut_masks=fut_masks) + total_interpolated += 1 + + # ------------------------------------------------------------------ + # Pass 2: forward extrapolation (ML prediction or CV fallback) + # ------------------------------------------------------------------ + # NuScenes gt_names that map to MetaDrive VEHICLE / PEDESTRIAN / CYCLIST and + # are therefore included in UniTraj inference predictions. + # Source: scenarionet/converter/nuscenes/type.py + # VEHICLE_TYPE → MetaDriveType.VEHICLE → object_type 1 + # HUMAN_TYPE → MetaDriveType.PEDESTRIAN → object_type 2 + # BICYCLE_TYPE → MetaDriveType.CYCLIST → object_type 3 + # Other categories (barrier, traffic_cone, movable_object.*, animal …) + # are not predicted — CV fallback is correct for those. 
+ _ML_CLASSES = { + # VEHICLE_TYPE + 'car', 'truck', 'bus', 'trailer', 'construction_vehicle', + 'vehicle.emergency.ambulance', 'vehicle.emergency.police', + # HUMAN_TYPE + 'pedestrian', + 'human.pedestrian.stroller', 'human.pedestrian.personal_mobility', + 'human.pedestrian.construction_worker', 'human.pedestrian.police_officer', + # BICYCLE_TYPE + 'motorcycle', 'bicycle', + } + # Motorized road users eligible for CA/CTR motion-model fitting. + _MOTORIZED_CLASSES = { + 'car', 'truck', 'bus', 'trailer', 'construction_vehicle', + 'vehicle.emergency.ambulance', 'vehicle.emergency.police', + 'motorcycle', + } + # UniTraj VEHICLE type — stationary instances (total displacement < 2 m) + # are filtered out during training, so ML predictions for these are + # unreliable when the agent is near-stationary. PEDESTRIAN and CYCLIST + # have no displacement filter and keep their ML predictions when slow. + _VEHICLE_CLASSES = { + 'car', 'truck', 'bus', 'trailer', 'construction_vehicle', + 'vehicle.emergency.ambulance', 'vehicle.emergency.police', + } + total_extrap = 0 + total_extrap_ml = 0 + total_extrap_ml_shifted = 0 # ML matches via earlier-frame fallback + cv_by_class: dict = defaultdict(int) # CV fallback counts broken down by class + cv_model_counter = defaultdict(int) # fallback model breakdown (CP/CV/CA/CTR/CATR) + cv_no_entry: int = 0 # CV because instance has zero NPZ rows (unrecoverable) + cv_dist_fail: int = 0 # CV because closest NPZ row exceeds pos_tol (fixable?) 
+ cv_sanity_fail: int = 0 # CV because ML 0.5 s point diverges too far from CV + if not args.no_extrapolate: + for _, global_indices in scene_to_indices.items(): + n_frames = len(global_indices) + present = [set(infos[gi]['instance_inds']) for gi in global_indices] + + orig_tracks = defaultdict(list) + for frame_pos, gi in enumerate(global_indices): + info = infos[gi] + for bi, inst_ind in enumerate(info['instance_inds']): + if not info['is_interpolated'][bi] and not info['is_extrapolated'][bi]: + orig_tracks[inst_ind].append((frame_pos, gi, bi)) + + for inst_ind, appearances in orig_tracks.items(): + last_fp, last_gi, last_bi = appearances[-1] + if last_fp >= n_frames - 1: + continue + info_ref = infos[last_gi] + class_name = info_ref['gt_names'][last_bi] + # Convert last observation to global frame. + x_g, y_g, z_g, l, w, h, yaw_g, vx_g, vy_g = _box_to_global( + info_ref['gt_boxes'][last_bi], + info_ref['gt_velocity'][last_bi], + info_ref) + t_ref = info_ref['timestamp'] * 1e-6 + + # ML-prediction lookup: find the NPZ row whose stored global + # position is closest to (x_g, y_g) for this instance. + # ForeSight stores integer indices; the NPZ stores token strings + # — convert via inst_token_map before querying ml_lookup. + # + # Primary match: prediction moment == last observed frame. + # Fallback match: use an earlier observed frame whose position + # aligns with a sliding-window prediction moment. This lets + # late-scene instances (last observed at frames 140-195 in 10 Hz + # terms, beyond the sf=115 coverage) still receive ML-based + # extrapolation. pred_time_offset (seconds) is then added to + # every dt so that step indices remain relative to the NPZ + # prediction moment rather than the last observation. 
+ pred_row = None + pred_time_offset = 0.0 # seconds from NPZ pred-moment to last_obs + if ml_lookup is not None and class_name in _ML_CLASSES: + inst_token = inst_token_map.get(inst_ind) + if inst_token is not None: + pred_row = _find_pred_index( + inst_token, x_g, y_g, ml_worlds, ml_lookup) + + if pred_row is None and len(appearances) > 1: + # Try earlier observations, most-recent first. + for earlier_fp, earlier_gi, earlier_bi in reversed(appearances[:-1]): + info_e = infos[earlier_gi] + state_e = _box_to_global( + info_e['gt_boxes'][earlier_bi], + info_e['gt_velocity'][earlier_bi], + info_e) + row = _find_pred_index( + inst_token, state_e[0], state_e[1], + ml_worlds, ml_lookup) + if row is not None: + dt_off = t_ref - info_e['timestamp'] * 1e-6 + # dt_off must be positive (earlier frame) and + # within the 6 s prediction horizon. + if 0 < dt_off < 6.0: + pred_row = row + pred_time_offset = dt_off + break + + # Sanity-check: reject the ML prediction if its 0.5 s point + # diverges too far from the CV prediction at the same time. + # Threshold = 0.5 + alpha * speed (m) — tighter for slow movers. + if pred_row is not None: + step_check = _dt_to_step(pred_time_offset + 0.5, args.unitraj_dt) + if 0 <= step_check < 60: + xy_ml_check = _pred_global_xy( + ml_preds[pred_row], ml_worlds[pred_row], step_check) + if xy_ml_check is not None: + speed = float(np.hypot(vx_g, vy_g)) + x_cv_check = x_g + vx_g * 0.5 + y_cv_check = y_g + vy_g * 0.5 + tol = 0.5 + 0.3 * speed + if np.hypot(xy_ml_check[0] - x_cv_check, + xy_ml_check[1] - y_cv_check) > tol: + pred_row = None + cv_sanity_fail += 1 + + # Diagnostic: classify why ML lookup failed for ML classes. 
+ if pred_row is None and class_name in _ML_CLASSES and ml_lookup is not None: + inst_token_d = inst_token_map.get(inst_ind) + if inst_token_d is not None: + all_cands = ml_lookup.get(str(inst_token_d), []) + if all_cands: + cv_dist_fail += 1 # has NPZ entries but all exceeded pos_tol + else: + cv_no_entry += 1 # never appeared as center agent in inference + + # Per-instance motion state (used by all fallback models below). + # Use box yaw directly — it is annotated by humans and more + # reliable than velocity-direction heading, which is inferred + # from position differences. + v0 = float(np.hypot(vx_g, vy_g)) + drive_yaw_g = yaw_g + + # Override ML for stationary vehicles: UniTraj filters out + # VEHICLE instances with total displacement < 2 m, so ML + # predictions for slow vehicles are unreliable. Pedestrians + # and cyclists have no displacement filter and keep their ML + # predictions even when near-stationary. + if (pred_row is not None + and v0 < args.cv_stationary_thr + and class_name in _VEHICLE_CLASSES): + pred_row = None + + # Fit CA/CTR model from history for motorized classes without ML. + ca_accel = 0.0 + ca_omega = 0.0 + if pred_row is None and class_name in _MOTORIZED_CLASSES: + ca_accel, ca_omega = _fit_motion_model( + appearances, infos, args.min_ca_history, + args.ca_noise_thr, args.ca_max_thr, args.ca_consistency_thr, + args.omega_noise_thr, args.omega_max_thr, + args.omega_consistency_thr) + + # Pre-compute ML handover state. Step 59 (the final UniTraj + # step) snaps close to zero for most predictions and is + # discarded. Steps 0–58 are reliable and used as-is. 
+ # _ml_last : last ML step we use (= 58, skipping only 59) + _span = 5 + _ml_last = 58 + ml_end_x, ml_end_y = x_g, y_g + ml_end_vx, ml_end_vy = vx_g, vy_g + ml_end_yaw = drive_yaw_g + ml_end_dt_offset = 0.0 + if pred_row is not None: + xy_end = _pred_global_xy( + ml_preds[pred_row], ml_worlds[pred_row], _ml_last) + if xy_end is not None: + ml_end_x, ml_end_y = xy_end + # Backward-difference velocity at _ml_last. + xy_bend = _pred_global_xy( + ml_preds[pred_row], ml_worlds[pred_row], _ml_last - _span) + if xy_bend is not None: + ml_end_vx = (ml_end_x - xy_bend[0]) / (_span * args.unitraj_dt) + ml_end_vy = (ml_end_y - xy_bend[1]) / (_span * args.unitraj_dt) + # Yaw: central difference centred on _ml_last. + xy_pend = _pred_global_xy( + ml_preds[pred_row], ml_worlds[pred_row], _ml_last - 2 * _span) + if xy_pend is not None: + dx_end = ml_end_x - xy_pend[0] + dy_end = ml_end_y - xy_pend[1] + if np.hypot(dx_end, dy_end) > 0.05: + ml_end_yaw = float(np.arctan2(dy_end, dx_end)) + ml_end_dt_offset = (_ml_last + 1) * args.unitraj_dt - pred_time_offset + + frames_added = 0 + for fp in range(last_fp + 1, n_frames): + if inst_ind in present[fp] or frames_added >= args.max_extrap_frames: + break + gi = global_indices[fp] + dt = infos[gi]['timestamp'] * 1e-6 - t_ref + + # --- Position, heading, velocity --- + # pred_time_offset shifts dt so steps are relative to the + # NPZ prediction moment (which may precede last_obs). + step = (_dt_to_step(pred_time_offset + dt, args.unitraj_dt) + if pred_row is not None else -1) + xy = (_pred_global_xy(ml_preds[pred_row], ml_worlds[pred_row], step) + if (pred_row is not None and 0 <= step <= _ml_last) else None) + + if xy is not None: + x_ep, y_ep = xy + # Heading: clamp to nearest step where central difference + # is valid (both neighbours in [0, 59]) so boundary steps + # borrow the adjacent central-diff yaw instead of using + # noisy one-sided differences. 
+ _span = 5 + step_cd = max(_span, min(step, 59 - _span)) + xy_n_cd = _pred_global_xy(ml_preds[pred_row], ml_worlds[pred_row], step_cd + _span) + xy_p_cd = _pred_global_xy(ml_preds[pred_row], ml_worlds[pred_row], step_cd - _span) + if xy_n_cd is not None and xy_p_cd is not None: + dx, dy = xy_n_cd[0] - xy_p_cd[0], xy_n_cd[1] - xy_p_cd[1] + else: + dx, dy = 0.0, 0.0 + yaw_ep = (float(np.arctan2(dy, dx)) + if np.hypot(dx, dy) > 0.05 else yaw_g) + # Velocity: same clamped central difference as yaw — + # reuse xy_n_cd / xy_p_cd already fetched above. + if xy_n_cd is not None and xy_p_cd is not None: + vx_ep = (xy_n_cd[0] - xy_p_cd[0]) / (2 * _span * args.unitraj_dt) + vy_ep = (xy_n_cd[1] - xy_p_cd[1]) / (2 * _span * args.unitraj_dt) + else: + vx_ep, vy_ep = vx_g, vy_g + state_g = (x_ep, y_ep, z_g, l, w, h, yaw_ep, vx_ep, vy_ep) + total_extrap_ml += 1 + if pred_time_offset > 0: + total_extrap_ml_shifted += 1 + else: + # Fallback hierarchy: CP → CATR/CA/CTR → CV. + # When ML was available but its horizon is exhausted, + # anchor from the ML step-59 endpoint to avoid a + # positional jump back to the original observation. 
+ cv_by_class[class_name] += 1 + if pred_row is not None: + ref_x, ref_y = ml_end_x, ml_end_y + ref_vx, ref_vy = ml_end_vx, ml_end_vy + ref_yaw = ml_end_yaw + ref_v0 = float(np.hypot(ref_vx, ref_vy)) + ref_dt = dt - ml_end_dt_offset + else: + ref_x, ref_y = x_g, y_g + ref_vx, ref_vy = vx_g, vy_g + ref_yaw = drive_yaw_g + ref_v0 = v0 + ref_dt = dt + if ref_v0 < args.cv_stationary_thr: + model_key = 'CP' + state_g = (ref_x, ref_y, z_g, l, w, h, ref_yaw, 0.0, 0.0) + elif ca_accel != 0.0 or ca_omega != 0.0: + if ca_accel != 0.0 and ca_omega != 0.0: + model_key = 'CATR' + elif ca_accel != 0.0: + model_key = 'CA' + else: + model_key = 'CTR' + x_ca, y_ca = _integrate_catr( + ref_x, ref_y, ref_yaw, ref_v0, ca_omega, ca_accel, ref_dt, + clamp_v=True) + yaw_ca = _normalize_angle(ref_yaw + ca_omega * ref_dt) + v_ca = max(0.0, ref_v0 + ca_accel * ref_dt) + state_g = (x_ca, y_ca, z_g, l, w, h, yaw_ca, + v_ca * np.cos(yaw_ca), v_ca * np.sin(yaw_ca)) + else: + model_key = 'CV' + state_g = (ref_x + ref_vx*ref_dt, ref_y + ref_vy*ref_dt, z_g, + l, w, h, ref_yaw, ref_vx, ref_vy) + cv_model_counter[model_key] += 1 + + x, y, z, l_, w_, h_, yaw, vx, vy = _state_global_to_lidar( + state_g, infos[gi]) + + # --- Future trajectory --- + fut_ts_n = infos[gi]['gt_agent_fut_trajs'].shape[1] + fut_positions_g = [np.array([state_g[0], state_g[1]])] + for k in range(1, fut_ts_n + 1): + fut_fp = fp + k + if fut_fp >= n_frames: + break + dt_fut = infos[global_indices[fut_fp]]['timestamp'] * 1e-6 - t_ref + step_fut = (_dt_to_step(pred_time_offset + dt_fut, args.unitraj_dt) + if pred_row is not None else -1) + xy_fut = (_pred_global_xy(ml_preds[pred_row], ml_worlds[pred_row], step_fut) + if (pred_row is not None and 0 <= step_fut <= _ml_last) else None) + if xy_fut is not None: + fut_positions_g.append(np.array([xy_fut[0], xy_fut[1]])) + else: + # Same ML-end anchoring as the main fallback above. 
+ if pred_row is not None: + f_x, f_y = ml_end_x, ml_end_y + f_vx, f_vy = ml_end_vx, ml_end_vy + f_yaw = ml_end_yaw + f_v0 = float(np.hypot(f_vx, f_vy)) + f_dt = dt_fut - ml_end_dt_offset + else: + f_x, f_y = x_g, y_g + f_vx, f_vy = vx_g, vy_g + f_yaw = drive_yaw_g + f_v0 = v0 + f_dt = dt_fut + if f_v0 < args.cv_stationary_thr: + fut_positions_g.append(np.array([f_x, f_y])) + elif ca_accel != 0.0 or ca_omega != 0.0: + x_ca, y_ca = _integrate_catr( + f_x, f_y, f_yaw, f_v0, ca_omega, ca_accel, f_dt, + clamp_v=True) + fut_positions_g.append(np.array([x_ca, y_ca])) + else: + fut_positions_g.append( + np.array([f_x + f_vx*f_dt, f_y + f_vy*f_dt])) + all_l = _global_xy_to_lidar_xy(np.array(fut_positions_g), infos[gi]) + fut_trajs = np.zeros((fut_ts_n, 2), dtype=np.float32) + fut_masks = np.zeros(fut_ts_n, dtype=np.float32) + for k in range(1, len(fut_positions_g)): + fut_trajs[k - 1] = all_l[k] - all_l[k - 1] + fut_masks[k - 1] = 1.0 + _append_annotation(infos[gi], x, y, z, l_, w_, h_, yaw, + vx, vy, inst_ind, class_name, + is_interp=False, is_extrap=True, + fut_trajs=fut_trajs, fut_masks=fut_masks) + present[fp].add(inst_ind) + frames_added += 1 + total_extrap += 1 + + n_orig = sum(int(np.sum(~i['is_interpolated'] & ~i['is_extrapolated'])) for i in infos) + n_interp = sum(int(np.sum( i['is_interpolated'])) for i in infos) + n_extrap = sum(int(np.sum( i['is_extrapolated'])) for i in infos) + n_cv_non_ml = sum(v for k, v in cv_by_class.items() if k not in _ML_CLASSES) + n_cv_ml_cls = sum(v for k, v in cv_by_class.items() if k in _ML_CLASSES) + print(f'Original annotations : {n_orig}') + print(f'Interpolated (CATR) : {n_interp}') + print(f'Extrapolated (ML exact) : {total_extrap_ml - total_extrap_ml_shifted}') + print(f'Extrapolated (ML shifted) : {total_extrap_ml_shifted}') + print(f'Extrapolated (CV fallback) : {total_extrap - total_extrap_ml}') + print(f' of which non-ML classes : {n_cv_non_ml} ' + f'(traffic_cone/barrier/… — CV is correct here)') + print(f' of which ML 
classes : {n_cv_ml_cls} ' + f'(temporal gap or no NPZ entry)') + print(f' no NPZ entry (unrecoverable) : {cv_no_entry}') + print(f' pos_tol exceeded (fixable?) : {cv_dist_fail}') + print(f' sanity filter (0.5s/speed) : {cv_sanity_fail}') + if cv_model_counter: + mdl = sorted(cv_model_counter.items()) + print(f' fallback model breakdown : ' + + ' '.join(f'{k}:{v}' for k, v in mdl)) + if cv_by_class: + by_cls = sorted(cv_by_class.items(), key=lambda x: -x[1]) + print(f' CV breakdown by class : ' + + ' '.join(f'{k}:{v}' for k, v in by_cls)) + print(f'Total annotations : {n_orig + n_interp + n_extrap}') + + # ------------------------------------------------------------------ + # Ego-distance filtering (two-level): + # Level 2 — drop instances with no original obs within --max-dist. + # Level 1 — tail-trim: drop all frames of an instance AFTER the last + # frame where it appears within --max-dist. Using a tail-trim + # rather than an independent per-frame check ensures trajectories + # stay continuous; a per-frame check would punch holes in curved + # tracks that briefly exit and re-enter the range boundary, + # producing disconnected fragments 50+ m away. + # ------------------------------------------------------------------ + if args.max_dist > 0: + # Level-2: instances that have at least one original obs within range. + valid_insts = set() + for info in infos: + boxes = info['gt_boxes'] + is_i = info['is_interpolated'] + is_e = info['is_extrapolated'] + for bi, inst_ind in enumerate(info['instance_inds']): + if not is_i[bi] and not is_e[bi]: + if np.hypot(boxes[bi, 0], boxes[bi, 1]) <= args.max_dist: + valid_insts.add(inst_ind) + print(f' Ego-distance filter: {len(valid_insts):,} instances have ' + f'≥1 original obs within {args.max_dist} m.') + + # Level-1 pre-pass: for each valid instance record the timestamp of + # its last frame that is within max_dist. 
Timestamps are + # monotonically increasing within a scene, so all frames up to and + # including this timestamp are kept; frames after it are trimmed. + last_ts_within: dict = {} + for info in infos: + ts = info['timestamp'] + boxes = info['gt_boxes'] + for bi, inst_ind in enumerate(info['instance_inds']): + if inst_ind not in valid_insts: + continue + if np.hypot(boxes[bi, 0], boxes[bi, 1]) <= args.max_dist: + if ts > last_ts_within.get(inst_ind, -1): + last_ts_within[inst_ind] = ts + + # Level-1 apply: keep box if instance is valid AND the frame + # timestamp is at or before the last within-range timestamp. + n_before = sum(len(i['instance_inds']) for i in infos) + for info in infos: + ts = info['timestamp'] + boxes = info['gt_boxes'] + keep = np.array([ + (inst_ind in valid_insts and + ts <= last_ts_within.get(inst_ind, -1)) + for bi, inst_ind in enumerate(info['instance_inds']) + ], dtype=bool) + info['gt_boxes'] = info['gt_boxes'][keep] + info['gt_velocity'] = info['gt_velocity'][keep] + info['gt_names'] = info['gt_names'][keep] + info['valid_flag'] = info['valid_flag'][keep] + info['num_lidar_pts'] = info['num_lidar_pts'][keep] + info['num_radar_pts'] = info['num_radar_pts'][keep] + info['is_interpolated'] = info['is_interpolated'][keep] + info['is_extrapolated'] = info['is_extrapolated'][keep] + info['gt_agent_fut_trajs'] = info['gt_agent_fut_trajs'][keep] + info['gt_agent_fut_masks'] = info['gt_agent_fut_masks'][keep] + info['instance_inds'] = [ + inst for inst, k in zip(info['instance_inds'], keep) if k] + n_after = sum(len(i['instance_inds']) for i in infos) + print(f' Removed {n_before - n_after:,} box entries ' + f'({n_before:,} → {n_after:,}).') + + print(f'Saving to {args.output} ...') + data['infos'] = infos + with open(args.output, 'wb') as f: + pickle.dump(data, f) + print('Done.') + + +# =========================================================================== +# Visualize sub-command +# 
=========================================================================== + +_C_OBS = '#1565c0' # observed, valid_flag=True (blue) +_C_INVALID = '#6a1b9a' # observed, valid_flag=False (purple) +_C_INTERP = '#e65100' # interpolated (CATR) (deep orange) +_C_EXTRAP = '#b71c1c' # extrapolated (dark red) +_C_EGO = '#2e7d32' # ego vehicle (dark green) +_C_BG = '#ffffff' # plot background (white) + + +def _ego_pose_global(info): + from pyquaternion import Quaternion + q = Quaternion(info['ego2global_rotation']) + pos = np.array(info['ego2global_translation'])[:2] + yaw = np.arctan2(2*(q.w*q.z + q.x*q.y), 1 - 2*(q.y**2 + q.z**2)) + return pos, yaw + + +def _draw_box(ax, cx, cy, length, width, yaw, color, + alpha_face=0.22, alpha_edge=0.85, lw=0.8, zorder=3): + import matplotlib.pyplot as plt + c, s = np.cos(yaw), np.sin(yaw) + R2 = np.array([[c, -s], [s, c]]) + half = np.array([[ length/2, width/2], + [-length/2, width/2], + [-length/2, -width/2], + [ length/2, -width/2]]) + corners = half @ R2.T + np.array([cx, cy]) + ax.add_patch(plt.Polygon(corners, closed=True, facecolor=color, edgecolor=color, + alpha=alpha_face, linewidth=lw, zorder=zorder)) + ax.add_patch(plt.Polygon(corners, closed=True, facecolor='none', edgecolor=color, + alpha=alpha_edge, linewidth=lw, zorder=zorder+1)) + + +def _visualize_scene(infos, gidxs, ax, title='', show_forecast=False, nusc=None): + import matplotlib.patches as mpatches + import matplotlib.lines as mlines + + inst_data = defaultdict(list) + ego_pos, ego_yaws = [], [] + + # Reference frame: first ego pose in the scene. + # All global-frame coordinates are transformed so that the first ego + # position is the origin and the first ego heading points along +X. 
+ p0, yaw0 = _ego_pose_global(infos[gidxs[0]]) + c0, s0 = np.cos(yaw0), np.sin(yaw0) + R_inv = np.array([[c0, s0], [-s0, c0]]) # global → first-ego-frame rotation + + def to_ego(xy_global): + """Transform (N,2) or (2,) from global frame to first-ego frame.""" + return (np.asarray(xy_global) - p0) @ R_inv.T + + # Map background — drawn first so trajectories render on top. + # nusc.explorer.render_ego_centric_map uses flat vehicle coordinates for + # the reference sample, which matches our first-ego-frame (+X forward, + # +Y left) exactly. We pass a generous axes_limit so the full scene + # extent is covered; explicit axis limits are set later from data. + if nusc is not None: + try: + first_token = infos[gidxs[0]]['token'] + lidar_token = nusc.get('sample', first_token)['data']['LIDAR_TOP'] + nusc.explorer.render_ego_centric_map( + sample_data_token=lidar_token, axes_limit=200, ax=ax) + except Exception: + pass # map unavailable — continue without it + + for frame_idx, gi in enumerate(gidxs): + info = infos[gi] + R, t = _lidar2global_RT(info) + boxes = info['gt_boxes'] + is_i = info['is_interpolated'] + is_e = info['is_extrapolated'] + + if len(boxes): + xy_g = (boxes[:, :3] @ R.T + t)[:, :2] + yaw_g = boxes[:, 6] + np.arctan2(R[1, 0], R[0, 0]) + xy_e = to_ego(xy_g) + yaw_e = yaw_g - yaw0 + valid = info['valid_flag'] + for bi in range(len(boxes)): + if is_i[bi]: + color = _C_INTERP + elif is_e[bi]: + color = _C_EXTRAP + elif valid[bi]: + color = _C_OBS + else: + color = _C_INVALID + + # Reconstruct future waypoints in first-ego frame. + # Deltas are in the current sample's lidar frame; accumulate + # them to get absolute lidar positions, then transform to global + # and finally to first-ego frame. 
+ fut_traj_raw = info['gt_agent_fut_trajs'][bi] # (T, 2) lidar deltas + fut_mask_raw = info['gt_agent_fut_masks'][bi] # (T,) + pos_l = np.array([boxes[bi, 0], boxes[bi, 1]], dtype=np.float64) + fut_pts_l = [] + for k in range(len(fut_traj_raw)): + if fut_mask_raw[k] < 0.5: + break + pos_l = pos_l + fut_traj_raw[k] + fut_pts_l.append(pos_l.copy()) + if fut_pts_l: + fpl = np.zeros((len(fut_pts_l), 3)) + fpl[:, :2] = np.array(fut_pts_l) + fut_pts_ego = to_ego((fpl @ R.T + t)[:, :2]) + else: + fut_pts_ego = np.zeros((0, 2)) + + inst_data[info['instance_inds'][bi]].append(( + frame_idx, + float(xy_e[bi, 0]), float(xy_e[bi, 1]), + float(boxes[bi, 3]), float(boxes[bi, 4]), + float(yaw_e[bi]), color, + fut_pts_ego, # index 7: (K,2) future waypoints in ego frame + )) + + pos, yaw = _ego_pose_global(info) + ego_pos.append(to_ego(pos)) + ego_yaws.append(yaw - yaw0) + + ego_pos = np.array(ego_pos) + + # Track lines — drawn for every consecutive pair in chronological order. + # Segments bridging a frame gap (diff > 1) are dashed to indicate the + # discontinuity; all other segments are solid. 
+ for frames in inst_data.values(): + frames = sorted(frames, key=lambda f: f[0]) + for k in range(len(frames) - 1): + f0, f1 = frames[k], frames[k + 1] + gap = f1[0] - f0[0] + ax.plot([f0[1], f1[1]], [f0[2], f1[2]], + color=f0[6], linewidth=0.7, + alpha=0.35 if gap > 1 else 0.55, + linestyle='--' if gap > 1 else '-', + zorder=2, solid_capstyle='round') + + # Boxes — draw back-to-front so valid observed renders on top + _alpha_face = {_C_EXTRAP: 0.13, _C_INTERP: 0.22, _C_INVALID: 0.18, _C_OBS: 0.22} + _alpha_edge = {_C_EXTRAP: 0.55, _C_INTERP: 0.85, _C_INVALID: 0.70, _C_OBS: 0.85} + _lw = {_C_EXTRAP: 0.5, _C_INTERP: 0.8, _C_INVALID: 0.7, _C_OBS: 0.8} + for target_color in [_C_EXTRAP, _C_INTERP, _C_INVALID, _C_OBS]: + for frames in inst_data.values(): + for f in frames: + if f[6] != target_color: + continue + _draw_box(ax, f[1], f[2], f[3], f[4], f[5], color=f[6], + alpha_face=_alpha_face[f[6]], + alpha_edge=_alpha_edge[f[6]], + lw=_lw[f[6]]) + + # Forecast trajectories (one dotted line per box appearance with valid future steps) + if show_forecast: + for frames in inst_data.values(): + for f in frames: + fut_pts = f[7] + if len(fut_pts) == 0: + continue + all_pts = np.vstack([[[f[1], f[2]]], fut_pts]) + ax.plot(all_pts[:, 0], all_pts[:, 1], + color=f[6], alpha=0.55, linewidth=1.0, + linestyle=':', zorder=4, solid_capstyle='round') + ax.scatter(fut_pts[:, 0], fut_pts[:, 1], + c=f[6], s=5, alpha=0.65, zorder=5, linewidths=0) + + # Ego trajectory and boxes + ax.plot(ego_pos[:, 0], ego_pos[:, 1], + color='#333333', linewidth=1.5, linestyle='--', zorder=6, alpha=0.8) + ax.scatter(*ego_pos[0], color='#333333', s=30, zorder=8, marker='o') + ax.scatter(*ego_pos[-1], color='#333333', s=30, zorder=8, marker='x') + for pos, yaw in zip(ego_pos, ego_yaws): + _draw_box(ax, pos[0], pos[1], 4.08, 1.73, yaw, + color=_C_EGO, alpha_face=0.30, alpha_edge=0.9, lw=1.0, zorder=7) + + # Explicit axis limits from data so the map imshow (which calls set_xlim/ + # set_ylim internally) does 
not dictate the final view. + all_x = ([f[1] for v in inst_data.values() for f in v] + + list(ego_pos[:, 0])) + all_y = ([f[2] for v in inst_data.values() for f in v] + + list(ego_pos[:, 1])) + if all_x: + xspan = max(all_x) - min(all_x) + yspan = max(all_y) - min(all_y) + span = max(xspan, yspan, 20.0) + pad = span * 0.08 + cx = (max(all_x) + min(all_x)) / 2 + cy = (max(all_y) + min(all_y)) / 2 + ax.set_xlim(cx - span / 2 - pad, cx + span / 2 + pad) + ax.set_ylim(cy - span / 2 - pad, cy + span / 2 + pad) + ax.set_aspect('equal') + ax.set_facecolor(_C_BG) + ax.grid(True, color='#bbbbbb', alpha=0.4, linewidth=0.5) + ax.tick_params(colors='#444444', labelsize=7) + for spine in ax.spines.values(): + spine.set_color('#cccccc') + ax.set_xlabel('X (m)', color='#444444', fontsize=8) + ax.set_ylabel('Y (m)', color='#444444', fontsize=8) + n_valid = sum(sum(1 for f in v if f[6] == _C_OBS) for v in inst_data.values()) + n_invalid = sum(sum(1 for f in v if f[6] == _C_INVALID) for v in inst_data.values()) + n_interp = sum(sum(1 for f in v if f[6] == _C_INTERP) for v in inst_data.values()) + n_extrap = sum(sum(1 for f in v if f[6] == _C_EXTRAP) for v in inst_data.values()) + ax.set_title( + f'{title} | Visible {n_valid} Occluded {n_invalid} ' + f'Interpolated {n_interp} Extrapolated {n_extrap}', + color='#111111', fontsize=8, pad=4) + + leg_handles = [ + mpatches.Patch(color='#aaaaaa', label='Driveable Area'), + mpatches.Patch(color=_C_OBS, label='Visible'), + mpatches.Patch(color=_C_INVALID, label='Occluded'), + mpatches.Patch(color=_C_INTERP, label='Interpolated'), + mpatches.Patch(color=_C_EXTRAP, label='Extrapolated'), + mpatches.Patch(color=_C_EGO, label='Ego'), + ] + if show_forecast: + leg_handles.append(mlines.Line2D( + [], [], color='#333333', linestyle=':', linewidth=1.2, alpha=0.7, + label='Forecast')) + ax.legend(handles=leg_handles, loc='upper right', framealpha=0.85, + fontsize=6, labelcolor='#111111', facecolor='white', + edgecolor='#cccccc', ncol=1) + + +# 
=========================================================================== +# Single-track helpers (individual track plots) +# =========================================================================== + +def _collect_all_tracks(infos, all_scenes): + """Return a list of per-instance track dicts across all scenes. + + Uses the same first-ego-frame coordinate system as ``_visualize_scene``: + each scene is centred on its first ego pose so coordinates are directly + comparable to the scene-level BEV plots. + + Each frame tuple is ``(frame_idx, x_e, y_e, l, w, yaw_e, color)`` — + identical to the entries stored in ``inst_data`` inside ``_visualize_scene``. + ``cv_frames`` is a parallel list of ``(x_e, y_e)`` CV predictions for each + extrapolated frame, used by the mismatch plot. + """ + tracks = [] + for gidxs in all_scenes: + sc_tok = infos[gidxs[0]]['scene_token'] + + # Same reference frame as _visualize_scene: first ego pose of the scene. + p0, yaw0 = _ego_pose_global(infos[gidxs[0]]) + c0, s0 = np.cos(yaw0), np.sin(yaw0) + R_ego = np.array([[c0, s0], [-s0, c0]]) # global → first-ego rotation + + def to_ego(xy): + return (np.asarray(xy) - p0) @ R_ego.T + + # inst_frames: same format as inst_data in _visualize_scene. + # inst_meta: raw (gi, bi) lists for CV anchor computation. 
+ inst_frames = defaultdict(list) + inst_meta = defaultdict(lambda: {'obs_raw': [], 'extrap_raw': []}) + + for frame_idx, gi in enumerate(gidxs): + info = infos[gi] + R, t = _lidar2global_RT(info) + boxes = info['gt_boxes'] + if not len(boxes): + continue + xy_g = (boxes[:, :3] @ R.T + t)[:, :2] # lidar → global + yaw_g = boxes[:, 6] + np.arctan2(R[1, 0], R[0, 0]) + xy_e = to_ego(xy_g) # global → first-ego + yaw_e = yaw_g - yaw0 + is_i = info['is_interpolated'] + is_e = info['is_extrapolated'] + valid = info['valid_flag'] + + for bi, inst_ind in enumerate(info['instance_inds']): + if is_i[bi]: + color = _C_INTERP + elif is_e[bi]: + color = _C_EXTRAP + inst_meta[inst_ind]['extrap_raw'].append((gi, bi)) + elif valid[bi]: + color = _C_OBS + inst_meta[inst_ind]['obs_raw'].append((gi, bi)) + else: + color = _C_INVALID + inst_meta[inst_ind]['obs_raw'].append((gi, bi)) + + inst_frames[inst_ind].append(( + frame_idx, + float(xy_e[bi, 0]), float(xy_e[bi, 1]), + float(boxes[bi, 3]), float(boxes[bi, 4]), + float(yaw_e[bi]), color, + )) + + for inst_ind, raw_frames in inst_frames.items(): + frames_sorted = sorted(raw_frames, key=lambda f: f[0]) + n_interp = sum(1 for f in frames_sorted if f[6] == _C_INTERP) + n_extrap = sum(1 for f in frames_sorted if f[6] == _C_EXTRAP) + + def _path_len(color): + pts = [(f[1], f[2]) for f in frames_sorted if f[6] == color] + return sum( + np.hypot(pts[i+1][0] - pts[i][0], pts[i+1][1] - pts[i][1]) + for i in range(len(pts) - 1) + ) + extrap_dist = _path_len(_C_EXTRAP) + interp_dist = _path_len(_C_INTERP) + + # CV mismatch: project last-observed velocity forward in ego frame. + # Rigid transform preserves distances, so scores are identical to + # computing in global frame. + max_cv_ml_diff = 0.0 + cv_frames = [] # (x_e, y_e) per extrapolated frame + meta = inst_meta[inst_ind] + if meta['obs_raw'] and meta['extrap_raw']: + # Use the last observed frame BEFORE the extrapolation starts. 
+ # obs_raw may contain re-appearance frames that come after the + # extrap gap; using those as the anchor gives negative dt and + # a completely wrong CV projection. + first_extrap_gi = meta['extrap_raw'][0][0] + obs_before = [(g, b) for g, b in meta['obs_raw'] + if g < first_extrap_gi] + if not obs_before: + obs_before = meta['obs_raw'] + anc_gi, anc_bi = obs_before[-1] + anc_info = infos[anc_gi] + state_g = _box_to_global( + anc_info['gt_boxes'][anc_bi], + anc_info['gt_velocity'][anc_bi], + anc_info) + xy0_e = to_ego(np.array([state_g[0], state_g[1]])) + x0_e, y0_e = float(xy0_e[0]), float(xy0_e[1]) + v_e = R_ego @ np.array([state_g[7], state_g[8]]) + vx_e, vy_e = float(v_e[0]), float(v_e[1]) + t0 = anc_info['timestamp'] * 1e-6 + + for gi_e, bi_e in meta['extrap_raw']: + dt = infos[gi_e]['timestamp'] * 1e-6 - t0 + cv_frames.append((x0_e + vx_e * dt, y0_e + vy_e * dt)) + + ext_e = [(f[1], f[2]) for f in frames_sorted if f[6] == _C_EXTRAP] + if cv_frames and len(cv_frames) == len(ext_e): + diffs = [np.hypot(e[0] - c[0], e[1] - c[1]) + for e, c in zip(ext_e, cv_frames)] + max_cv_ml_diff = float(max(diffs)) + + # Class name from last observed frame (or first extrap/interp). + class_name = '' + for src, idx in [(meta['obs_raw'], -1), (meta['extrap_raw'], 0)]: + if src: + gi_c, bi_c = src[idx] + class_name = infos[gi_c]['gt_names'][bi_c] + break + + tracks.append({ + 'inst_ind': inst_ind, + 'scene_token': sc_tok, + 'first_sample_token': infos[gidxs[0]]['token'], + 'class_name': class_name, + 'frames': frames_sorted, # (frame_idx, x_e, y_e, l, w, yaw_e, color) + 'cv_frames': cv_frames, # (x_e, y_e) per extrap frame + 'n_interp': n_interp, + 'n_extrap': n_extrap, + 'extrap_dist': extrap_dist, + 'interp_dist': interp_dist, + 'max_cv_ml_diff': max_cv_ml_diff, + }) + return tracks + + +def _visualize_single_track_ax(track, ax, show_cv=False, nusc=None): + """Render one instance track using the same style as ``_visualize_scene``. 
+ + Draws track lines between adjacent frames and oriented boxes at every + frame position. The view is auto-zoomed to the track's spatial extent. + Coordinates are already in first-ego frame (set by ``_collect_all_tracks``). + + If *nusc* is provided, a map background is rendered first using the scene's + first-frame LIDAR_TOP token. The auto-zoom axis limits are always applied + last so the view is unchanged. + """ + import matplotlib.patches as mpatches + import matplotlib.lines as mlines + + frames = track['frames'] # (frame_idx, x_e, y_e, l, w, yaw_e, color) + cv_frames = track['cv_frames'] + + if not frames: + return + + # --- Compute auto-zoom bounds first (needed for map axes_limit) ---------- + all_xy = [(f[1], f[2]) for f in frames] + if show_cv: + all_xy.extend(cv_frames) + xs = [p[0] for p in all_xy] + ys = [p[1] for p in all_xy] + span = max(max(xs) - min(xs), max(ys) - min(ys), 8.0) + pad = span * 0.20 + cx = (max(xs) + min(xs)) / 2 + cy = (max(ys) + min(ys)) / 2 + xlim = (cx - span / 2 - pad, cx + span / 2 + pad) + ylim = (cy - span / 2 - pad, cy + span / 2 + pad) + + # --- Map background (rendered before track elements) --------------------- + ax.set_facecolor(_C_BG) + if nusc is not None: + try: + first_token = track.get('first_sample_token') + if first_token: + lidar_token = nusc.get('sample', first_token)['data']['LIDAR_TOP'] + # axes_limit must reach the farthest corner of the track from + # the scene origin (0, 0) so the full map tile is loaded. 
+ map_limit = max( + abs(cx) + span / 2 + pad, + abs(cy) + span / 2 + pad, + ) + 20.0 + nusc.explorer.render_ego_centric_map( + sample_data_token=lidar_token, + axes_limit=map_limit, + ax=ax) + except Exception: + pass # map unavailable — continue without it + + # --- Track lines (all consecutive pairs; dashed across frame gaps) ------- + for k in range(len(frames) - 1): + f0, f1 = frames[k], frames[k + 1] + gap = f1[0] - f0[0] + ax.plot([f0[1], f1[1]], [f0[2], f1[2]], + color=f0[6], lw=0.7, + alpha=0.35 if gap > 1 else 0.55, + linestyle='--' if gap > 1 else '-', + zorder=2, solid_capstyle='round') + + # --- Boxes back-to-front (matching _visualize_scene render order) -------- + _alpha_face = {_C_EXTRAP: 0.13, _C_INTERP: 0.22, _C_INVALID: 0.18, _C_OBS: 0.22} + _alpha_edge = {_C_EXTRAP: 0.55, _C_INTERP: 0.85, _C_INVALID: 0.70, _C_OBS: 0.85} + _lw_d = {_C_EXTRAP: 0.5, _C_INTERP: 0.8, _C_INVALID: 0.7, _C_OBS: 0.8} + for target_color in [_C_EXTRAP, _C_INTERP, _C_INVALID, _C_OBS]: + for f in frames: + if f[6] != target_color: + continue + _draw_box(ax, f[1], f[2], f[3], f[4], f[5], color=f[6], + alpha_face=_alpha_face[f[6]], + alpha_edge=_alpha_edge[f[6]], + lw=_lw_d[f[6]]) + + # --- CV baseline --------------------------------------------------------- + if show_cv and len(cv_frames) >= 2: + cv_xs = [p[0] for p in cv_frames] + cv_ys = [p[1] for p in cv_frames] + ax.plot(cv_xs, cv_ys, color='#ffd54f', lw=1.0, ls='--', + marker='.', ms=2, alpha=0.75, zorder=3) + + # Star at the obs→extrap/interp handover point. 
+ obs_frames = [f for f in frames if f[6] in (_C_OBS, _C_INVALID)] + if obs_frames: + lf = obs_frames[-1] + ax.scatter(lf[1], lf[2], color='#333333', s=22, zorder=5, + marker='*', linewidths=0) + + # --- Apply auto-zoom limits (identical whether map was rendered or not) -- + ax.set_xlim(*xlim) + ax.set_ylim(*ylim) + ax.set_aspect('equal') + ax.grid(True, color='#bbbbbb', alpha=0.4, lw=0.5) + ax.tick_params(colors='#444444', labelsize=6) + for spine in ax.spines.values(): + spine.set_color('#cccccc') + + leg_handles = [ + mpatches.Patch(color=_C_OBS, label='Visible'), + mpatches.Patch(color=_C_INVALID, label='Occluded'), + mpatches.Patch(color=_C_INTERP, label='Interpolated'), + mpatches.Patch(color=_C_EXTRAP, label='Extrapolated'), + ] + if show_cv: + leg_handles.append(mlines.Line2D( + [], [], color='#ffd54f', ls='--', lw=1.2, label='CV baseline')) + ax.legend(handles=leg_handles, loc='upper right', framealpha=0.85, + fontsize=5.5, labelcolor='#111111', facecolor='white', + edgecolor='#cccccc', ncol=1) + + cls = track['class_name'] + tok = track['scene_token'][:6] + n_e = track['n_extrap'] + n_i = track['n_interp'] + d_e = track.get('extrap_dist', 0.0) + d_i = track.get('interp_dist', 0.0) + diff = track['max_cv_ml_diff'] + parts = [f'{cls}', f'{tok}…'] + if n_i > 0: + parts.append(f'{n_i} interp ({d_i:.1f} m)') + if n_e > 0: + parts.append(f'{n_e} extrap ({d_e:.1f} m)') + if show_cv: + parts.append(f'max CV Δ={diff:.1f} m') + ax.set_title(' | '.join(parts), color='#111111', fontsize=7, pad=3) + + +def _save_single_track_grid(tracks, out_path, suptitle, score_fn, + top_n=12, show_cv=False, nusc=None): + """Save a 4-column grid of single-track BEV plots ranked by *score_fn*.""" + import matplotlib.pyplot as plt + + ranked = sorted(tracks, key=score_fn, reverse=True)[:top_n] + if not ranked: + print(f'No tracks for {out_path}, skipping.') + return + + ncols = 4 + nrows = (len(ranked) + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, + figsize=(5 * 
ncols, 5 * nrows), + facecolor=_C_BG) + axes = np.array(axes).flatten() + for i, track in enumerate(ranked): + _visualize_single_track_ax(track, axes[i], show_cv=show_cv, nusc=nusc) + for j in range(len(ranked), len(axes)): + axes[j].set_visible(False) + + fig.suptitle(suptitle, color='#111111', fontsize=11, y=1.01) + fig.subplots_adjust(hspace=0.5, wspace=0.3) + plt.savefig(out_path, dpi=200, bbox_inches='tight', + facecolor=fig.get_facecolor()) + print(f'Saved → {out_path}') + plt.close(fig) + + +# =========================================================================== +# Ego-distance histogram helper +# =========================================================================== + +def _count_missing_under_dist(infos, all_scenes, max_dist): + """Count instances whose last annotation ends before the scene and within max_dist. + + For each instance in each scene, finds the last frame where it appears + (observed or extrapolated). If that frame is not the final scene frame AND + the ego distance at that position is below *max_dist*, the instance is + counted as still-missing — our augmentation did not reach the scene end. + + Returns + ------- + (n_instances, n_frames) : + n_instances — number of such instance-scene pairs. + n_frames — total unannotated frame-slots those instances leave behind. 
+ """ + n_instances = 0 + n_gaps = 0 + for gidxs in all_scenes: + if len(gidxs) <= 1: + continue + inst_last = {} # inst_ind → (frame_pos, gi, bi) of last annotation + for frame_pos, gi in enumerate(gidxs): + info = infos[gi] + for bi, inst_ind in enumerate(info['instance_inds']): + inst_last[inst_ind] = (frame_pos, gi, bi) + for inst_ind, (fp, gi, bi) in inst_last.items(): + if fp >= len(gidxs) - 1: + continue # reaches the scene end — not missing + box = infos[gi]['gt_boxes'][bi] + if float(np.hypot(box[0], box[1])) < max_dist: + n_instances += 1 + n_gaps += len(gidxs) - 1 - fp + return n_instances, n_gaps + + +def _save_ego_dist_hist(infos, out_path, max_dist=60, n_missing=None): + """Overlay histograms of ego-distance for original vs augmented annotations. + + Boxes are stored in the lidar frame whose origin is the ego vehicle, so + ego distance = hypot(box_x, box_y) directly. + + * 'Original' — boxes where both is_interpolated and is_extrapolated are False. + * 'Augmented' — all boxes (original + interpolated + extrapolated). + + Parameters + ---------- + max_dist : upper x-axis limit in metres (bins span [0, max_dist]). + n_missing : optional (n_instances, n_frames) from ``_count_missing_under_dist`` + — when provided, annotated as text on the plot. 
+ """ + import matplotlib.pyplot as plt + + orig_dists, aug_dists = [], [] + for info in infos: + boxes = info['gt_boxes'] + is_i = info['is_interpolated'] + is_e = info['is_extrapolated'] + for bi in range(len(info['instance_inds'])): + d = float(np.hypot(boxes[bi, 0], boxes[bi, 1])) + aug_dists.append(d) + if not is_i[bi] and not is_e[bi]: + orig_dists.append(d) + + orig_dists = np.array(orig_dists) + aug_dists = np.array(aug_dists) + + bins = np.linspace(0, max_dist, 51) + + fig, ax = plt.subplots(figsize=(10, 5), facecolor=_C_BG) + ax.set_facecolor(_C_BG) + orig_filt = orig_dists[orig_dists <= max_dist] + aug_filt = aug_dists[aug_dists <= max_dist] + ax.hist(orig_filt, bins=bins, color=_C_OBS, alpha=0.75, + label=f'Original (n={len(orig_filt):,})') + ax.hist(aug_filt, bins=bins, color=_C_EXTRAP, alpha=0.55, + label=f'Augmented (n={len(aug_filt):,})') + + ax.set_xlabel('Ego distance (m)', color='#444444', fontsize=10) + ax.set_ylabel('Count', color='#444444', fontsize=10) + ax.set_title( + f'Ego-distance distribution (0–{max_dist} m) — original vs augmented', + color='#111111', fontsize=11, pad=6) + ax.tick_params(colors='#444444') + ax.legend(labelcolor='#111111', facecolor='white', edgecolor='#cccccc', + framealpha=0.9, fontsize=10) + for spine in ax.spines.values(): + spine.set_color('#cccccc') + ax.grid(True, color='#bbbbbb', alpha=0.4, linewidth=0.5) + + if n_missing is not None: + n_inst, n_frm = n_missing + print(f'Still missing (< {max_dist} m): ' + f'{n_inst:,} instances, {n_frm:,} frame-slots') + + plt.tight_layout() + plt.savefig(out_path, dpi=200, bbox_inches='tight', + facecolor=fig.get_facecolor()) + print(f'Saved → {out_path}') + plt.close(fig) + + +# =========================================================================== +# Visualize sub-command +# =========================================================================== + +def cmd_visualize(args): + import matplotlib.pyplot as plt + + print(f'Loading {args.pkl} ...') + with open(args.pkl, 
'rb') as f: + data = pickle.load(f) + infos = data['infos'] + + nusc = None + if getattr(args, 'nuscenes_dataroot', None): + try: + from nuscenes import NuScenes + print(f'Loading NuScenes from {args.nuscenes_dataroot} ' + f'(version={args.nuscenes_version}) for map rendering …') + nusc = NuScenes(version=args.nuscenes_version, + dataroot=args.nuscenes_dataroot, verbose=False) + print(' NuScenes loaded.') + except Exception as e: + print(f' WARNING: could not load NuScenes ({e}); map will be skipped.') + + scene_map = defaultdict(list) + for gi, info in enumerate(infos): + scene_map[info['scene_token']].append(gi) + for sc in scene_map: + scene_map[sc].sort(key=lambda i: infos[i]['timestamp']) + all_scenes = list(scene_map.values()) + + os.makedirs(args.output_dir, exist_ok=True) + + if args.scene_idx is not None: + scene_list = [all_scenes[args.scene_idx]] + fig, ax = plt.subplots(figsize=(10, 10), facecolor=_C_BG) + gidxs = scene_list[0] + sc_tok = infos[gidxs[0]]['scene_token'] + _visualize_scene(infos, gidxs, ax, title=f'Scene {sc_tok[:8]}…', nusc=nusc) + out = args.output or os.path.join(args.output_dir, 'occ_viz.png') + plt.tight_layout() + plt.savefig(out, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor()) + print(f'Saved → {out}') + plt.close(fig) + return + + def _score_interp(gidxs): + return sum(int(infos[gi]['is_interpolated'].sum()) for gi in gidxs) + + def _score_extrap(gidxs): + return sum(int(infos[gi]['is_extrapolated'].sum()) for gi in gidxs) + + def _score_invalid(gidxs): + total = 0 + for gi in gidxs: + info = infos[gi] + obs_mask = ~info['is_interpolated'] & ~info['is_extrapolated'] + total += int((obs_mask & ~info['valid_flag']).sum()) + return total + + top_interp = sorted(all_scenes, key=_score_interp, reverse=True)[:args.num_scenes] + top_extrap = sorted(all_scenes, key=_score_extrap, reverse=True)[:args.num_scenes] + top_invalid = sorted(all_scenes, key=_score_invalid, reverse=True)[:args.num_scenes] + + def 
_save_grid(scene_list, out_path, suptitle, show_forecast=False): + ncols = min(3, len(scene_list)) + nrows = (len(scene_list) + ncols - 1) // ncols + fig, axes = plt.subplots(nrows, ncols, figsize=(7*ncols, 7*nrows), facecolor=_C_BG) + axes = np.array(axes).flatten() + for i, gidxs in enumerate(scene_list): + sc_tok = infos[gidxs[0]]['scene_token'] + _visualize_scene(infos, gidxs, axes[i], + title=f'Scene {sc_tok[:8]}…', + show_forecast=show_forecast, nusc=nusc) + for j in range(len(scene_list), len(axes)): + axes[j].set_visible(False) + fig.suptitle(suptitle, color='#111111', fontsize=11, y=1.01) + fig.subplots_adjust(hspace=0.35, wspace=0.25) + plt.savefig(out_path, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor()) + print(f'Saved → {out_path}') + plt.close(fig) + + _save_grid(top_interp, + os.path.join(args.output_dir, 'occ_viz_interp.png'), + f'Top {args.num_scenes} scenes by interpolated annotation count') + _save_grid(top_extrap, + os.path.join(args.output_dir, 'occ_viz_extrap.png'), + f'Top {args.num_scenes} scenes by extrapolated annotation count') + _save_grid(top_invalid, + os.path.join(args.output_dir, 'occ_viz_invalid.png'), + f'Top {args.num_scenes} scenes by invalid (0-pt) observed annotation count') + # ------------------------------------------------------------------ + # Single-track grids + # ------------------------------------------------------------------ + print('Collecting per-instance track data …') + all_tracks = _collect_all_tracks(infos, all_scenes) + extrap_tracks = [t for t in all_tracks if t['n_extrap'] > 0] + interp_tracks = [t for t in all_tracks if t['n_interp'] > 0] + + _save_single_track_grid( + extrap_tracks, + os.path.join(args.output_dir, 'occ_viz_extrap_single.png'), + 'Top 12 extrapolated tracks — most distance covered during extrapolation', + score_fn=lambda t: t['extrap_dist'], + nusc=nusc, + ) + _save_single_track_grid( + interp_tracks, + os.path.join(args.output_dir, 'occ_viz_interp_single.png'), + 'Top 12 
interpolated tracks — most distance covered during interpolation', + score_fn=lambda t: t['interp_dist'], + nusc=nusc, + ) + _save_single_track_grid( + extrap_tracks, + os.path.join(args.output_dir, 'occ_viz_extrap_single_mismatch.png'), + 'Top 12 extrapolated tracks — largest extrapolation-vs-CV position mismatch', + score_fn=lambda t: t['max_cv_ml_diff'], + show_cv=True, + nusc=nusc, + ) + _save_ego_dist_hist( + infos, + os.path.join(args.output_dir, 'occ_viz_ego_dist.png'), + ) + + +# =========================================================================== +# Entry point +# =========================================================================== + +def main(): + parser = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter) + sub = parser.add_subparsers(dest='cmd', required=True) + + p_conv = sub.add_parser('convert', help='Run annotation pipeline and save pkl.') + p_conv.add_argument('--data-dir', default='data/infos', + help='Directory containing the nuScenes info pkls. 
'
+                        'The train and val splits are processed '
+                        'unless --input is given.')
+    p_conv.add_argument('--input', default='data/infos/nuscenes_infos_val.pkl',
+                        help='Single input pkl (overrides --data-dir loop)')
+    p_conv.add_argument('--output', default='data/infos/nuscenes_infos_val_occ.pkl',
+                        help='Single output pkl (required when --input is set)')
+    p_conv.add_argument('--predictions', default='data/occlusions/fmae_nuscenes_v1trainval_3class_sw_inference.npz',
+                        help='Path to UniTraj inference NPZ for ML-based '
+                             'extrapolation; falls back to CV when no match')
+    p_conv.add_argument('--nuscenes-dataroot', default='data/nuscenes',
+                        dest='nuscenes_dataroot',
+                        help='nuScenes dataset root (contains v1.0-trainval/); '
+                             'required when --predictions is used to map '
+                             'integer instance indices to token strings')
+    p_conv.add_argument('--no-extrapolate', action='store_true',
+                        help='Disable forward extrapolation')
+    p_conv.add_argument('--max-extrap-frames', type=int, default=9999,
+                        help='Max frames to extrapolate forward per instance. '
+                             'Default 9999 is effectively unlimited — the loop '
+                             'always stops when the instance reappears or the '
+                             'scene ends. Lower this to restrict to the ML '
+                             'prediction horizon (e.g. 12 at 2 Hz).')
+    p_conv.add_argument('--unitraj-dt', type=float, default=0.1,
+                        dest='unitraj_dt',
+                        help='Seconds per UniTraj prediction step. 
'
+                             '0.1 for 10 Hz models (default), 0.5 for 2 Hz models.')
+    p_conv.add_argument('--cv-stationary-thr', type=float, default=0.3,
+                        dest='cv_stationary_thr',
+                        help='Speed threshold (m/s) below which the CV fallback '
+                             'uses constant position instead of constant velocity '
+                             '(default: 0.3).')
+    # CA / CTR motion-model fitting
+    p_conv.add_argument('--min-ca-history', type=int, default=4,
+                        dest='min_ca_history',
+                        help='Minimum consecutive observed frames required to fit '
+                             'a CA/CTR model (default: 4).')
+    p_conv.add_argument('--ca-noise-thr', type=float, default=0.5,
+                        dest='ca_noise_thr',
+                        help='Minimum |acceleration| (m/s²) to apply CA model; '
+                             'below this the track is treated as constant velocity '
+                             '(default: 0.5).')
+    p_conv.add_argument('--ca-max-thr', type=float, default=6.0,
+                        dest='ca_max_thr',
+                        help='Maximum |acceleration| (m/s²) accepted as realistic; '
+                             'above this CA is rejected (default: 6.0).')
+    p_conv.add_argument('--ca-consistency-thr', type=float, default=0.5,
+                        dest='ca_consistency_thr',
+                        help='Maximum standard deviation of per-interval acceleration '
+                             '(m/s²) for the CA model to be accepted (default: 0.5).')
+    p_conv.add_argument('--omega-noise-thr', type=float, default=0.07,
+                        dest='omega_noise_thr',
+                        help='Minimum |turn rate| (rad/s) to apply CTR model; '
+                             'below this the track is treated as straight '
+                             '(default: 0.07 ≈ 4 deg/s).')
+    p_conv.add_argument('--omega-max-thr', type=float, default=0.6,
+                        dest='omega_max_thr',
+                        help='Maximum |turn rate| (rad/s) accepted as realistic; '
+                             'above this CTR is rejected (default: 0.6 ≈ 34 deg/s).')
+    p_conv.add_argument('--omega-consistency-thr', type=float, default=0.12,
+                        dest='omega_consistency_thr',
+                        help='Maximum standard deviation of per-interval turn rate '
+                             '(rad/s) for the CTR model to be accepted (default: 0.12).')
+    p_conv.add_argument('--max-dist', type=float, default=60.0,
+                        dest='max_dist',
+                        help='Two-level ego-distance filter applied after conversion: '
+                             'drops 
instances with no original obs within this range ' + 'and removes per-frame boxes beyond it. ' + 'Set to 0 or a negative value to disable (default: 60.0).') + + p_viz = sub.add_parser('visualize', help='BEV plots of occluded annotations.') + p_viz.add_argument('--pkl', default='data/infos/nuscenes_infos_val_occ.pkl') + p_viz.add_argument('--scene-idx', type=int, default=None, + help='Scene index (0-based)') + p_viz.add_argument('--num-scenes', type=int, default=6, + help='Number of example scenes to plot (default 6)') + p_viz.add_argument('--output-dir', default='vis/gt_occ', + help='Directory for output images') + p_viz.add_argument('--output', default=None, + help='Single output filename (overrides --output-dir)') + p_viz.add_argument('--nuscenes-dataroot', default='data/nuscenes', + dest='nuscenes_dataroot', + help='nuScenes dataset root (e.g. /data/sets/nuscenes). ' + 'When provided, drivable-area map is rendered behind ' + 'each BEV plot using nusc.explorer.render_ego_centric_map.') + p_viz.add_argument('--nuscenes-version', default='v1.0-trainval', + dest='nuscenes_version', + help='nuScenes version string (default: v1.0-trainval)') + + args = parser.parse_args() + if args.cmd == 'convert': + if args.input is not None: + if args.output is None: + parser.error('--output is required when --input is specified') + cmd_convert(args) + else: + _SPLITS = ('train', 'val') + for split in _SPLITS: + inp = os.path.join(args.data_dir, f'nuscenes_infos_{split}.pkl') + out = os.path.join(args.data_dir, f'nuscenes_infos_{split}_occ.pkl') + if not os.path.exists(inp): + print(f'Skipping {inp} (not found)') + continue + args.input = inp + args.output = out + print(f'\n=== {split} ===') + cmd_convert(args) + else: + cmd_visualize(args) + + +if __name__ == '__main__': + main() diff --git a/tools/dist_test.sh b/tools/dist_test.sh new file mode 100644 index 0000000..033365e --- /dev/null +++ b/tools/dist_test.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +CONFIG=$1 +CHECKPOINT=$2 
+GPUS=$3 +PORT=${PORT:-29610} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} diff --git a/tools/dist_train.sh b/tools/dist_train.sh new file mode 100644 index 0000000..43e95de --- /dev/null +++ b/tools/dist_train.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +CONFIG=$1 +GPUS=$2 +PORT=${PORT:-28651} + +PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ +python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ + $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} diff --git a/tools/fuse_conv_bn.py b/tools/fuse_conv_bn.py new file mode 100644 index 0000000..9aff402 --- /dev/null +++ b/tools/fuse_conv_bn.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import torch +from mmcv.runner import save_checkpoint +from torch import nn as nn + +from mmdet3d.apis import init_model + + +def fuse_conv_bn(conv, bn): + """During inference, the functionary of batch norm layers is turned off but + only the mean and var alone channels are used, which exposes the chance to + fuse it with the preceding conv layers to save computations and simplify + network structures.""" + conv_w = conv.weight + conv_b = conv.bias if conv.bias is not None else torch.zeros_like( + bn.running_mean) + + factor = bn.weight / torch.sqrt(bn.running_var + bn.eps) + conv.weight = nn.Parameter(conv_w * + factor.reshape([conv.out_channels, 1, 1, 1])) + conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias) + return conv + + +def fuse_module(m): + last_conv = None + last_conv_name = None + + for name, child in m.named_children(): + if isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)): + if last_conv is None: # only fuse BN that is after Conv + continue + fused_conv = fuse_conv_bn(last_conv, child) + m._modules[last_conv_name] = fused_conv + # To reduce changes, set BN as Identity 
instead of deleting it. + m._modules[name] = nn.Identity() + last_conv = None + elif isinstance(child, nn.Conv2d): + last_conv = child + last_conv_name = name + else: + fuse_module(child) + return m + + +def parse_args(): + parser = argparse.ArgumentParser( + description='fuse Conv and BN layers in a model') + parser.add_argument('config', help='config file path') + parser.add_argument('checkpoint', help='checkpoint file path') + parser.add_argument('out', help='output path of the converted model') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + # build the model from a config file and a checkpoint file + model = init_model(args.config, args.checkpoint) + # fuse conv and bn layers of the model + fused_model = fuse_module(model) + save_checkpoint(fused_model, args.out) + + +if __name__ == '__main__': + main() diff --git a/tools/kmeans/kmeans_det.py b/tools/kmeans/kmeans_det.py new file mode 100644 index 0000000..3c8f218 --- /dev/null +++ b/tools/kmeans/kmeans_det.py @@ -0,0 +1,34 @@ +import os +import pickle +from tqdm import tqdm + +import numpy as np +import matplotlib.pyplot as plt +from sklearn.cluster import KMeans + +import mmcv + +os.makedirs('data/kmeans', exist_ok=True) +os.makedirs('vis/kmeans', exist_ok=True) + +K = 900 +DIS_THRESH = 55 + +fp = 'data/infos/nuscenes_infos_train.pkl' +data = mmcv.load(fp) +data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"])) +center = [] +for idx in tqdm(range(len(data_infos))): + boxes = data_infos[idx]['gt_boxes'][:,:3] + if len(boxes) == 0: + continue + distance = np.linalg.norm(boxes[:, :2], axis=1) + center.append(boxes[distance < DIS_THRESH]) +center = np.concatenate(center, axis=0) +print("start clustering, may take a few minutes.") +cluster = KMeans(n_clusters=K).fit(center).cluster_centers_ +plt.scatter(cluster[:,0], cluster[:,1]) +plt.savefig(f'vis/kmeans/det_anchor_{K}', bbox_inches='tight') +others = np.array([1,1,1,1,0,0,0,0])[np.newaxis].repeat(K, axis=0) 
+cluster = np.concatenate([cluster, others], axis=1) +np.save(f'data/kmeans/kmeans_det_{K}.npy', cluster) \ No newline at end of file diff --git a/tools/kmeans/kmeans_map.py b/tools/kmeans/kmeans_map.py new file mode 100644 index 0000000..13b34bc --- /dev/null +++ b/tools/kmeans/kmeans_map.py @@ -0,0 +1,34 @@ +import os +import pickle +from tqdm import tqdm + +import numpy as np +import matplotlib.pyplot as plt +from sklearn.cluster import KMeans + +import mmcv + +K = 100 +num_sample = 20 + +fp = 'data/infos/nuscenes_infos_train.pkl' +data = mmcv.load(fp) +data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"])) +center = [] +for idx in tqdm(range(len(data_infos))): + for cls, geoms in data_infos[idx]["map_annos"].items(): + for geom in geoms: + center.append(geom.mean(axis=0)) +center = np.stack(center, axis=0) +center = KMeans(n_clusters=K).fit(center).cluster_centers_ +delta_y = np.linspace(-4, 4, num_sample) +delta_x = np.zeros([num_sample]) +delta = np.stack([delta_x, delta_y], axis=-1) +vecs = center[:, np.newaxis] + delta[np.newaxis] + +for i in range(K): + x = vecs[i, :, 0] + y = vecs[i, :, 1] + plt.plot(x, y, linewidth=1, marker='o', linestyle='-', markersize=2) +plt.savefig(f'vis/kmeans/map_anchor_{K}', bbox_inches='tight') +np.save(f'data/kmeans/kmeans_map_{K}.npy', vecs) \ No newline at end of file diff --git a/tools/kmeans/kmeans_motion.py b/tools/kmeans/kmeans_motion.py new file mode 100644 index 0000000..1c5e74c --- /dev/null +++ b/tools/kmeans/kmeans_motion.py @@ -0,0 +1,101 @@ +import os +import pickle +from tqdm import tqdm + +import numpy as np +import matplotlib.pyplot as plt +from sklearn.cluster import KMeans + +import mmcv + +CLASSES = [ + "car", + "truck", + "construction_vehicle", + "bus", + "trailer", + "barrier", + "motorcycle", + "bicycle", + "pedestrian", + "traffic_cone", +] + +def lidar2agent(trajs_offset, boxes): + origin = np.zeros((trajs_offset.shape[0], 1, 2), dtype=np.float32) + trajs_offset = 
np.concatenate([origin, trajs_offset], axis=1) + trajs = trajs_offset.cumsum(axis=1) + yaws = - boxes[:, 6] + rot_sin = np.sin(yaws) + rot_cos = np.cos(yaws) + rot_mat_T = np.stack( + [ + np.stack([rot_cos, rot_sin]), + np.stack([-rot_sin, rot_cos]), + ] + ) + trajs_new = np.einsum('aij,jka->aik', trajs, rot_mat_T) + trajs_new = trajs_new[:, 1:] + return trajs_new + +K = 6 +DIS_THRESH = 55 + +fp = 'data/infos/nuscenes_infos_train.pkl' +data = mmcv.load(fp) +data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"])) +intention = dict() +for i in range(len(CLASSES)): + intention[i] = [] +for idx in tqdm(range(len(data_infos))): + info = data_infos[idx] + boxes = info['gt_boxes'] + names = info['gt_names'] + fut_masks = info['gt_agent_fut_masks'] + trajs = info['gt_agent_fut_trajs'] + velos = info['gt_velocity'] + labels = [] + for cat in names: + if cat in CLASSES: + labels.append(CLASSES.index(cat)) + else: + labels.append(-1) + labels = np.array(labels) + if len(boxes) == 0: + continue + for i in range(len(CLASSES)): + cls_mask = (labels == i) + box_cls = boxes[cls_mask] + fut_masks_cls = fut_masks[cls_mask] + trajs_cls = trajs[cls_mask] + velos_cls = velos[cls_mask] + + distance = np.linalg.norm(box_cls[:, :2], axis=1) + mask = np.logical_and( + fut_masks_cls.sum(axis=1) == 12, + distance < DIS_THRESH, + ) + trajs_cls = trajs_cls[mask] + box_cls = box_cls[mask] + velos_cls = velos_cls[mask] + + trajs_agent = lidar2agent(trajs_cls, box_cls) + if trajs_agent.shape[0] == 0: + continue + intention[i].append(trajs_agent) + +clusters = [] +for i in range(len(CLASSES)): + intention_cls = np.concatenate(intention[i], axis=0).reshape(-1, 24) + if intention_cls.shape[0] < K: + continue + cluster = KMeans(n_clusters=K).fit(intention_cls).cluster_centers_ + cluster = cluster.reshape(-1, 12, 2) + clusters.append(cluster) + for j in range(K): + plt.scatter(cluster[j, :, 0], cluster[j, :,1]) + plt.savefig(f'vis/kmeans/motion_intention_{CLASSES[i]}_{K}', 
bbox_inches='tight') + plt.close() + +clusters = np.stack(clusters, axis=0) +np.save(f'data/kmeans/kmeans_motion_{K}.npy', clusters) \ No newline at end of file diff --git a/tools/kmeans/kmeans_plan.py b/tools/kmeans/kmeans_plan.py new file mode 100644 index 0000000..33a74f7 --- /dev/null +++ b/tools/kmeans/kmeans_plan.py @@ -0,0 +1,39 @@ +import os +import pickle +from tqdm import tqdm + +import numpy as np +import matplotlib.pyplot as plt +from sklearn.cluster import KMeans + +import mmcv + +K = 6 + +fp = 'data/infos/nuscenes_infos_train.pkl' +data = mmcv.load(fp) +data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"])) +navi_trajs = [[], [], []] +for idx in tqdm(range(len(data_infos))): + info = data_infos[idx] + plan_traj = info['gt_ego_fut_trajs'].cumsum(axis=-2) + plan_mask = info['gt_ego_fut_masks'] + cmd = info['gt_ego_fut_cmd'].astype(np.int32) + cmd = cmd.argmax(axis=-1) + if not plan_mask.sum() == 6: + continue + navi_trajs[cmd].append(plan_traj) + +clusters = [] +for trajs in navi_trajs: + trajs = np.concatenate(trajs, axis=0).reshape(-1, 12) + cluster = KMeans(n_clusters=K).fit(trajs).cluster_centers_ + cluster = cluster.reshape(-1, 6, 2) + clusters.append(cluster) + for j in range(K): + plt.scatter(cluster[j, :, 0], cluster[j, :,1]) +plt.savefig(f'vis/kmeans/plan_{K}', bbox_inches='tight') +plt.close() + +clusters = np.stack(clusters, axis=0) +np.save(f'data/kmeans/kmeans_plan_{K}.npy', clusters) \ No newline at end of file diff --git a/tools/test.py b/tools/test.py new file mode 100644 index 0000000..6a27db2 --- /dev/null +++ b/tools/test.py @@ -0,0 +1,334 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
"""Test (and evaluate) a trained detector.

Builds the test dataset and dataloader, loads a checkpoint, runs
(optionally distributed) inference, and then saves / formats / evaluates /
visualizes the results depending on the CLI flags.  When the config
registers a ``WandbLoggerHook``, evaluation metrics are also logged to
Weights & Biases under a ``val/`` prefix.
"""
import argparse
import mmcv
import os
from os import path as osp

import torch
import warnings
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (
    get_dist_info,
    init_dist,
    load_checkpoint,
    wrap_fp16_model,
)

from mmdet.apis import single_gpu_test, multi_gpu_test, set_random_seed
from mmdet.datasets import replace_ImageToTensor, build_dataset
from mmdet.datasets import build_dataloader as build_dataloader_origin
from mmdet.models import build_detector

from projects.mmdet3d_plugin.datasets.builder import build_dataloader
from projects.mmdet3d_plugin.apis.test import custom_multi_gpu_test


def parse_args():
    """Parse the command line arguments for testing/evaluation.

    Returns:
        argparse.Namespace: Parsed arguments.  ``--options`` (deprecated)
        is folded into ``--eval-options``; ``LOCAL_RANK`` is exported for
        single-process launches.
    """
    parser = argparse.ArgumentParser(
        description="MMDet test (and eval) a model"
    )
    parser.add_argument("config", help="test config file path")
    parser.add_argument("checkpoint", help="checkpoint file")
    parser.add_argument("--out", help="output result file in pickle format")
    parser.add_argument(
        "--fuse-conv-bn",
        action="store_true",
        help="Whether to fuse conv and bn, this will slightly increase"
        "the inference speed",
    )
    parser.add_argument(
        "--format-only",
        action="store_true",
        help="Format the output results without perform evaluation. It is"
        "useful when you want to format the result to a specific format and "
        "submit it to the test server",
    )
    parser.add_argument(
        "--eval",
        type=str,
        nargs="+",
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC',
    )
    parser.add_argument("--show", action="store_true", help="show results")
    parser.add_argument(
        "--show-dir", help="directory where results will be saved"
    )
    parser.add_argument(
        "--gpu-collect",
        action="store_true",
        help="whether to use gpu to collect results.",
    )
    parser.add_argument(
        "--tmpdir",
        help="tmp directory used for collecting results from multiple "
        "workers, available when gpu-collect is not specified",
    )
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument(
        "--deterministic",
        action="store_true",
        help="whether to set deterministic options for CUDNN backend.",
    )
    parser.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file. If the value to "
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        "Note that the quotation marks are necessary and that no white space "
        "is allowed.",
    )
    parser.add_argument(
        "--options",
        nargs="+",
        action=DictAction,
        help="custom options for evaluation, the key-value pair in xxx=yyy "
        "format will be kwargs for dataset.evaluate() function (deprecate), "
        "change to --eval-options instead.",
    )
    parser.add_argument(
        "--eval-options",
        nargs="+",
        action=DictAction,
        help="custom options for evaluation, the key-value pair in xxx=yyy "
        "format will be kwargs for dataset.evaluate() function",
    )
    parser.add_argument(
        "--launcher",
        choices=["none", "pytorch", "slurm", "mpi"],
        default="none",
        help="job launcher",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--result_file", type=str, default=None)
    parser.add_argument("--show_only", action="store_true")
    args = parser.parse_args()
    if "LOCAL_RANK" not in os.environ:
        os.environ["LOCAL_RANK"] = str(args.local_rank)

    if args.options and args.eval_options:
        raise ValueError(
            "--options and --eval-options cannot be both specified, "
            "--options is deprecated in favor of --eval-options"
        )
    if args.options:
        warnings.warn("--options is deprecated in favor of --eval-options")
        args.eval_options = args.options
    return args


def main():
    """Entry point: run inference (or load cached results) and post-process."""
    args = parse_args()

    assert (
        args.out or args.eval or args.format_only or args.show or args.show_dir
    ), (
        "Please specify at least one operation (save/eval/format/show the "
        'results / save the results) with the argument "--out", "--eval"'
        ', "--format-only", "--show" or "--show-dir"'
    )

    if args.eval and args.format_only:
        raise ValueError("--eval and --format_only cannot be both specified")

    if args.out is not None and not args.out.endswith((".pkl", ".pickle")):
        raise ValueError("The output file must be a pkl file.")

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get("custom_imports", None):
        from mmcv.utils import import_modules_from_strings

        import_modules_from_strings(**cfg["custom_imports"])

    # import modules from plugin/xx, registry will be updated
    if hasattr(cfg, "plugin"):
        if cfg.plugin:
            import importlib

            if hasattr(cfg, "plugin_dir"):
                plugin_dir = cfg.plugin_dir
                _module_dir = os.path.dirname(plugin_dir)
                _module_dir = _module_dir.split("/")
                _module_path = _module_dir[0]

                for m in _module_dir[1:]:
                    _module_path = _module_path + "." + m
                print(_module_path)
                plg_lib = importlib.import_module(_module_path)
            else:
                # import dir is the dirpath for the config file
                _module_dir = os.path.dirname(args.config)
                _module_dir = _module_dir.split("/")
                _module_path = _module_dir[0]
                for m in _module_dir[1:]:
                    _module_path = _module_path + "." + m
                print(_module_path)
                plg_lib = importlib.import_module(_module_path)

    # set cudnn_benchmark
    if cfg.get("cudnn_benchmark", False):
        torch.backends.cudnn.benchmark = True

    cfg.model.pretrained = None
    # in case the test dataset is concatenated
    samples_per_gpu = 1
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        samples_per_gpu = cfg.data.test.pop("samples_per_gpu", 1)
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline
            )
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        samples_per_gpu = max(
            [ds_cfg.pop("samples_per_gpu", 1) for ds_cfg in cfg.data.test]
        )
        if samples_per_gpu > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == "none":
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # set work dir
    if cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    cfg.data.test.work_dir = cfg.work_dir
    print('work_dir: ', cfg.work_dir)

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    print("distributed:", distributed)
    if distributed:
        # plugin dataloader supports a non-shuffling distributed sampler
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            nonshuffler_sampler=dict(type="DistributedSampler"),
        )
    else:
        data_loader = build_dataloader_origin(
            dataset,
            samples_per_gpu=samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
        )

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg"))
    fp16_cfg = cfg.get("fp16", None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu")
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)
    # old versions did not save class info in checkpoints, this workaround is
    # for backward compatibility
    if "CLASSES" in checkpoint.get("meta", {}):
        model.CLASSES = checkpoint["meta"]["CLASSES"]
    else:
        model.CLASSES = dataset.CLASSES
    # palette for visualization in segmentation tasks
    if "PALETTE" in checkpoint.get("meta", {}):
        model.PALETTE = checkpoint["meta"]["PALETTE"]
    elif hasattr(dataset, "PALETTE"):
        # segmentation dataset has `PALETTE` attribute
        model.PALETTE = dataset.PALETTE

    if args.result_file is not None:
        # reuse previously dumped results instead of re-running inference
        outputs = mmcv.load(args.result_file)
    elif not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
        )
        outputs = custom_multi_gpu_test(
            model, data_loader, args.tmpdir, args.gpu_collect
        )

    # only rank 0 saves / evaluates / logs
    rank, _ = get_dist_info()
    if rank == 0:
        if args.out:
            print(f"\nwriting results to {args.out}")
            mmcv.dump(outputs, args.out)
        kwargs = {} if args.eval_options is None else args.eval_options
        if args.show_only:
            eval_kwargs = cfg.get("evaluation", {}).copy()
            # hard-code way to remove EvalHook args
            for key in [
                "interval",
                "tmpdir",
                "start",
                "gpu_collect",
                "save_best",
                "rule",
            ]:
                eval_kwargs.pop(key, None)
            eval_kwargs.update(kwargs)
            dataset.show(outputs, show=True, **eval_kwargs)
        elif args.format_only:
            dataset.format_results(outputs, **kwargs)
        elif args.eval:
            eval_kwargs = cfg.get("evaluation", {}).copy()
            # hard-code way to remove EvalHook args
            for key in [
                "interval",
                "tmpdir",
                "start",
                "gpu_collect",
                "save_best",
                "rule",
            ]:
                eval_kwargs.pop(key, None)
            eval_kwargs.update(dict(metric=args.eval, **kwargs))
            print(eval_kwargs)
            results_dict = dataset.evaluate(outputs, **eval_kwargs)
            print(results_dict)

            # Log to wandb if config has WandbLoggerHook.
            # cfg.get() guards configs that define no log_config at all.
            log_config = cfg.get("log_config", None)
            hooks = log_config.hooks if log_config is not None else []
            for hook in hooks:
                if hook.type == "WandbLoggerHook":
                    import wandb

                    if wandb.run is None:
                        wandb.init(**hook.init_kwargs)
                    wandb.log({'val/' + k: v for k, v in results_dict.items()})
                    break


if __name__ == "__main__":
    # use fork so workers_per_gpu can be > 1; tolerate a start method that
    # was already set by the launcher instead of crashing with RuntimeError
    try:
        torch.multiprocessing.set_start_method("fork")
    except RuntimeError:
        pass
    main()
/dev/null +++ b/tools/train.py @@ -0,0 +1,321 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from __future__ import division +import sys +import os + +print(sys.executable, os.path.abspath(__file__)) +# import init_paths # for conda pkgs submitting method +import argparse +import copy +import mmcv +import time +import torch +import warnings +from mmcv import Config, DictAction +from mmcv.runner import get_dist_info, init_dist +from os import path as osp + +from mmdet import __version__ as mmdet_version +from mmdet.apis import train_detector +from mmdet.datasets import build_dataset +from mmdet.models import build_detector +from mmdet.utils import collect_env, get_root_logger +from mmdet.apis import set_random_seed +from torch import distributed as dist +from datetime import timedelta + +import cv2 + +cv2.setNumThreads(8) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Train a detector") + parser.add_argument("config", help="train config file path") + parser.add_argument("--work-dir", help="the dir to save logs and models") + parser.add_argument( + "--resume-from", help="the checkpoint file to resume from" + ) + parser.add_argument( + "--no-validate", + action="store_true", + help="whether not to evaluate the checkpoint during training", + ) + group_gpus = parser.add_mutually_exclusive_group() + group_gpus.add_argument( + "--gpus", + type=int, + help="number of gpus to use " + "(only applicable to non-distributed training)", + ) + group_gpus.add_argument( + "--gpu-ids", + type=int, + nargs="+", + help="ids of gpus to use " + "(only applicable to non-distributed training)", + ) + parser.add_argument("--seed", type=int, default=0, help="random seed") + parser.add_argument( + "--deterministic", + action="store_true", + help="whether to set deterministic options for CUDNN backend.", + ) + parser.add_argument( + "--options", + nargs="+", + action=DictAction, + help="override some settings in the used config, the key-value pair " + "in xxx=yyy 
format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + parser.add_argument( + "--cfg-options", + nargs="+", + action=DictAction, + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file. If the value to " + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + "Note that the quotation marks are necessary and that no white space " + "is allowed.", + ) + parser.add_argument( + "--dist-url", + type=str, + default="auto", + help="dist url for init process, such as tcp://localhost:8000", + ) + parser.add_argument("--gpus-per-machine", type=int, default=8) + parser.add_argument( + "--launcher", + choices=["none", "pytorch", "slurm", "mpi", "mpi_nccl"], + default="none", + help="job launcher", + ) + parser.add_argument("--local_rank", type=int, default=0) + parser.add_argument( + "--autoscale-lr", + action="store_true", + help="automatically scale lr with the number of gpus", + ) + args = parser.parse_args() + if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(args.local_rank) + + if args.options and args.cfg_options: + raise ValueError( + "--options and --cfg-options cannot be both specified, " + "--options is deprecated in favor of --cfg-options" + ) + if args.options: + warnings.warn("--options is deprecated in favor of --cfg-options") + args.cfg_options = args.options + + return args + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + # import modules from string list. 
+ if cfg.get("custom_imports", None): + from mmcv.utils import import_modules_from_strings + + import_modules_from_strings(**cfg["custom_imports"]) + + # import modules from plguin/xx, registry will be updated + if hasattr(cfg, "plugin"): + if cfg.plugin: + import importlib + + if hasattr(cfg, "plugin_dir"): + plugin_dir = cfg.plugin_dir + _module_dir = os.path.dirname(plugin_dir) + _module_dir = _module_dir.split("/") + _module_path = _module_dir[0] + + for m in _module_dir[1:]: + _module_path = _module_path + "." + m + print(_module_path) + plg_lib = importlib.import_module(_module_path) + else: + # import dir is the dirpath for the config file + _module_dir = os.path.dirname(args.config) + _module_dir = _module_dir.split("/") + _module_path = _module_dir[0] + for m in _module_dir[1:]: + _module_path = _module_path + "." + m + print(_module_path) + plg_lib = importlib.import_module(_module_path) + from projects.mmdet3d_plugin.apis.train import custom_train_model + + # set cudnn_benchmark + if cfg.get("cudnn_benchmark", False): + torch.backends.cudnn.benchmark = True + + # work_dir is determined in this priority: CLI > segment in file > filename + if args.work_dir is not None: + # update configs according to CLI args if args.work_dir is not None + cfg.work_dir = args.work_dir + elif cfg.get("work_dir", None) is None: + # use config filename as default work_dir if cfg.work_dir is None + cfg.work_dir = osp.join( + "./work_dirs", osp.splitext(osp.basename(args.config))[0] + ) + if args.resume_from is not None: + cfg.resume_from = args.resume_from + if args.gpu_ids is not None: + cfg.gpu_ids = args.gpu_ids + else: + cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) + + if args.autoscale_lr: + # apply the linear scaling rule (https://arxiv.org/abs/1706.02677) + cfg.optimizer["lr"] = cfg.optimizer["lr"] * len(cfg.gpu_ids) / 8 + + # init distributed env first, since logger depends on the dist info. 
+ if args.launcher == "none": + distributed = False + elif args.launcher == "mpi_nccl": + distributed = True + + import mpi4py.MPI as MPI + + comm = MPI.COMM_WORLD + mpi_local_rank = comm.Get_rank() + mpi_world_size = comm.Get_size() + print( + "MPI local_rank=%d, world_size=%d" + % (mpi_local_rank, mpi_world_size) + ) + + # num_gpus = torch.cuda.device_count() + device_ids_on_machines = list(range(args.gpus_per_machine)) + str_ids = list(map(str, device_ids_on_machines)) + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str_ids) + torch.cuda.set_device(mpi_local_rank % args.gpus_per_machine) + + dist.init_process_group( + backend="nccl", + init_method=args.dist_url, + world_size=mpi_world_size, + rank=mpi_local_rank, + timeout=timedelta(seconds=7200), + ) + + cfg.gpu_ids = range(mpi_world_size) + print("cfg.gpu_ids:", cfg.gpu_ids) + else: + distributed = True + init_dist( + args.launcher, timeout=timedelta(seconds=7200), **cfg.dist_params + ) + # re-set gpu_ids with distributed training mode + _, world_size = get_dist_info() + cfg.gpu_ids = range(world_size) + + # create work_dir + mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) + # dump config + cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) + # init the logger before other steps + timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) + log_file = osp.join(cfg.work_dir, f"{timestamp}.log") + # specify logger name, if we still use 'mmdet', the output info will be + # filtered and won't be saved in the log_file + # TODO: ugly workaround to judge whether we are training det or seg model + logger = get_root_logger( + log_file=log_file, log_level=cfg.log_level + ) + + # init the meta dict to record some important information such as + # environment info and seed, which will be logged + meta = dict() + # log env info + env_info_dict = collect_env() + env_info = "\n".join([(f"{k}: {v}") for k, v in env_info_dict.items()]) + dash_line = "-" * 60 + "\n" + logger.info( + "Environment info:\n" + dash_line + 
env_info + "\n" + dash_line + ) + meta["env_info"] = env_info + meta["config"] = cfg.pretty_text + + # log some basic info + logger.info(f"Distributed training: {distributed}") + logger.info(f"Config:\n{cfg.pretty_text}") + + # set random seeds + if args.seed is not None: + logger.info( + f"Set random seed to {args.seed}, " + f"deterministic: {args.deterministic}" + ) + set_random_seed(args.seed, deterministic=args.deterministic) + cfg.seed = args.seed + meta["seed"] = args.seed + meta["exp_name"] = osp.basename(args.config) + + model = build_detector( + cfg.model, train_cfg=cfg.get("train_cfg"), test_cfg=cfg.get("test_cfg") + ) + model.init_weights() + logger.info(f"Model:\n{model}") + + cfg.data.train.work_dir = cfg.work_dir + cfg.data.val.work_dir = cfg.work_dir + datasets = [build_dataset(cfg.data.train)] + + if len(cfg.workflow) == 2: + val_dataset = copy.deepcopy(cfg.data.val) + # in case we use a dataset wrapper + if "dataset" in cfg.data.train: + val_dataset.pipeline = cfg.data.train.dataset.pipeline + else: + val_dataset.pipeline = cfg.data.train.pipeline + # set test_mode=False here in deep copied config + # which do not affect AP/AR calculation later + # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow # noqa + val_dataset.test_mode = False + datasets.append(build_dataset(val_dataset)) + if cfg.checkpoint_config is not None: + # save mmdet version, config file content and class names in + # checkpoints as meta data + cfg.checkpoint_config.meta = dict( + mmdet_version=mmdet_version, + config=cfg.pretty_text, + CLASSES=datasets[0].CLASSES, + ) + # add an attribute for visualization convenience + model.CLASSES = datasets[0].CLASSES + if hasattr(cfg, "plugin"): + custom_train_model( + model, + datasets, + cfg, + distributed=distributed, + validate=(not args.no_validate), + timestamp=timestamp, + meta=meta, + ) + else: + train_detector( + model, + datasets, + cfg, + distributed=distributed, + 
validate=(not args.no_validate), + timestamp=timestamp, + meta=meta, + ) + + +if __name__ == "__main__": + torch.multiprocessing.set_start_method( + "fork" + ) # use fork workers_per_gpu can be > 1 + main() diff --git a/tools/upload_wandb_results.py b/tools/upload_wandb_results.py new file mode 100644 index 0000000..e33ff31 --- /dev/null +++ b/tools/upload_wandb_results.py @@ -0,0 +1,26 @@ +# Uploads pre-run test results to wandb as individual runs. +# Each entry in results_dict is a run name -> metrics dict. +# Run: python tools/log_tests.py + +import wandb +from math import nan + +# wandb project to log to (name is set per-run below) +init_kwargs = dict( + entity='trailab', + project='ForeSight', + name=None, +) + +# Add entries as: results_dict[''] = {} +results_dict = {} +results_dict['sparsedrive_r50_stage2_4gpu_nomap_novalmask'] = {'img_bbox_NuScenes/car_AP_dist_0.5': 0.3508, 'img_bbox_NuScenes/car_AP_dist_1.0': 0.6027, 'img_bbox_NuScenes/car_AP_dist_2.0': 0.7565, 'img_bbox_NuScenes/car_AP_dist_4.0': 0.8333, 'img_bbox_NuScenes/car_trans_err': 0.3839, 'img_bbox_NuScenes/car_scale_err': 0.1462, 'img_bbox_NuScenes/car_orient_err': 0.0745, 'img_bbox_NuScenes/car_vel_err': 0.2122, 'img_bbox_NuScenes/car_attr_err': 0.1931, 'img_bbox_NuScenes/mATE': 0.559, 'img_bbox_NuScenes/mASE': 0.273, 'img_bbox_NuScenes/mAOE': 0.5665, 'img_bbox_NuScenes/mAVE': 0.2748, 'img_bbox_NuScenes/mAAE': 0.1843, 'img_bbox_NuScenes/truck_AP_dist_0.5': 0.0822, 'img_bbox_NuScenes/truck_AP_dist_1.0': 0.2492, 'img_bbox_NuScenes/truck_AP_dist_2.0': 0.4667, 'img_bbox_NuScenes/truck_AP_dist_4.0': 0.5558, 'img_bbox_NuScenes/truck_trans_err': 0.6161, 'img_bbox_NuScenes/truck_scale_err': 0.1997, 'img_bbox_NuScenes/truck_orient_err': 0.14, 'img_bbox_NuScenes/truck_vel_err': 0.1974, 'img_bbox_NuScenes/truck_attr_err': 0.1849, 'img_bbox_NuScenes/construction_vehicle_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/construction_vehicle_AP_dist_1.0': 0.0171, 'img_bbox_NuScenes/construction_vehicle_AP_dist_2.0': 
0.0804, 'img_bbox_NuScenes/construction_vehicle_AP_dist_4.0': 0.1867, 'img_bbox_NuScenes/construction_vehicle_trans_err': 0.9051, 'img_bbox_NuScenes/construction_vehicle_scale_err': 0.5006, 'img_bbox_NuScenes/construction_vehicle_orient_err': 1.3696, 'img_bbox_NuScenes/construction_vehicle_vel_err': 0.1487, 'img_bbox_NuScenes/construction_vehicle_attr_err': 0.3778, 'img_bbox_NuScenes/bus_AP_dist_0.5': 0.057, 'img_bbox_NuScenes/bus_AP_dist_1.0': 0.2654, 'img_bbox_NuScenes/bus_AP_dist_2.0': 0.4994, 'img_bbox_NuScenes/bus_AP_dist_4.0': 0.6913, 'img_bbox_NuScenes/bus_trans_err': 0.6795, 'img_bbox_NuScenes/bus_scale_err': 0.2195, 'img_bbox_NuScenes/bus_orient_err': 0.1293, 'img_bbox_NuScenes/bus_vel_err': 0.4734, 'img_bbox_NuScenes/bus_attr_err': 0.2621, 'img_bbox_NuScenes/trailer_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/trailer_AP_dist_1.0': 0.0295, 'img_bbox_NuScenes/trailer_AP_dist_2.0': 0.1291, 'img_bbox_NuScenes/trailer_AP_dist_4.0': 0.2746, 'img_bbox_NuScenes/trailer_trans_err': 0.9455, 'img_bbox_NuScenes/trailer_scale_err': 0.2619, 'img_bbox_NuScenes/trailer_orient_err': 0.6304, 'img_bbox_NuScenes/trailer_vel_err': 0.3321, 'img_bbox_NuScenes/trailer_attr_err': 0.037, 'img_bbox_NuScenes/barrier_AP_dist_0.5': 0.362, 'img_bbox_NuScenes/barrier_AP_dist_1.0': 0.5598, 'img_bbox_NuScenes/barrier_AP_dist_2.0': 0.669, 'img_bbox_NuScenes/barrier_AP_dist_4.0': 0.7042, 'img_bbox_NuScenes/barrier_trans_err': 0.3256, 'img_bbox_NuScenes/barrier_scale_err': 0.2842, 'img_bbox_NuScenes/barrier_orient_err': 0.1083, 'img_bbox_NuScenes/barrier_vel_err': nan, 'img_bbox_NuScenes/barrier_attr_err': nan, 'img_bbox_NuScenes/motorcycle_AP_dist_0.5': 0.1546, 'img_bbox_NuScenes/motorcycle_AP_dist_1.0': 0.4048, 'img_bbox_NuScenes/motorcycle_AP_dist_2.0': 0.5362, 'img_bbox_NuScenes/motorcycle_AP_dist_4.0': 0.6075, 'img_bbox_NuScenes/motorcycle_trans_err': 0.5019, 'img_bbox_NuScenes/motorcycle_scale_err': 0.2554, 'img_bbox_NuScenes/motorcycle_orient_err': 0.8037, 
'img_bbox_NuScenes/motorcycle_vel_err': 0.3619, 'img_bbox_NuScenes/motorcycle_attr_err': 0.2457, 'img_bbox_NuScenes/bicycle_AP_dist_0.5': 0.2086, 'img_bbox_NuScenes/bicycle_AP_dist_1.0': 0.4115, 'img_bbox_NuScenes/bicycle_AP_dist_2.0': 0.4935, 'img_bbox_NuScenes/bicycle_AP_dist_4.0': 0.5589, 'img_bbox_NuScenes/bicycle_trans_err': 0.4162, 'img_bbox_NuScenes/bicycle_scale_err': 0.2588, 'img_bbox_NuScenes/bicycle_orient_err': 1.2732, 'img_bbox_NuScenes/bicycle_vel_err': 0.1427, 'img_bbox_NuScenes/bicycle_attr_err': 0.0065, 'img_bbox_NuScenes/pedestrian_AP_dist_0.5': 0.1827, 'img_bbox_NuScenes/pedestrian_AP_dist_1.0': 0.4434, 'img_bbox_NuScenes/pedestrian_AP_dist_2.0': 0.66, 'img_bbox_NuScenes/pedestrian_AP_dist_4.0': 0.7615, 'img_bbox_NuScenes/pedestrian_trans_err': 0.5612, 'img_bbox_NuScenes/pedestrian_scale_err': 0.2874, 'img_bbox_NuScenes/pedestrian_orient_err': 0.5691, 'img_bbox_NuScenes/pedestrian_vel_err': 0.3303, 'img_bbox_NuScenes/pedestrian_attr_err': 0.1671, 'img_bbox_NuScenes/traffic_cone_AP_dist_0.5': 0.527, 'img_bbox_NuScenes/traffic_cone_AP_dist_1.0': 0.6597, 'img_bbox_NuScenes/traffic_cone_AP_dist_2.0': 0.7376, 'img_bbox_NuScenes/traffic_cone_AP_dist_4.0': 0.7733, 'img_bbox_NuScenes/traffic_cone_trans_err': 0.2547, 'img_bbox_NuScenes/traffic_cone_scale_err': 0.316, 'img_bbox_NuScenes/traffic_cone_orient_err': nan, 'img_bbox_NuScenes/traffic_cone_vel_err': nan, 'img_bbox_NuScenes/traffic_cone_attr_err': nan, 'img_bbox_NuScenes/NDS': 0.5210452731598937, 'img_bbox_NuScenes/mAP': 0.4135818961243956, 'img_bbox_NuScenes/amota': 0.37841917596535796, 'img_bbox_NuScenes/amotp': 1.254887412737341, 'img_bbox_NuScenes/recall': 0.48778598773767107, 'img_bbox_NuScenes/motar': 0.6749810391493056, 'img_bbox_NuScenes/gt': 14556.714285714286, 'img_bbox_NuScenes/mota': 0.34960712411451783, 'img_bbox_NuScenes/motp': 0.6241441130285599, 'img_bbox_NuScenes/mt': 2432.0, 'img_bbox_NuScenes/ml': 2342.0, 'img_bbox_NuScenes/faf': 45.385501647071784, 'img_bbox_NuScenes/tp': 
60681.0, 'img_bbox_NuScenes/fp': 12701.0, 'img_bbox_NuScenes/fn': 40297.0, 'img_bbox_NuScenes/ids': 919.0, 'img_bbox_NuScenes/frag': 574.0, 'img_bbox_NuScenes/tid': 1.98160291062767, 'img_bbox_NuScenes/lgd': 2.5866146019633036, 'car_EPA': 0.4893063663779985, 'pedestrian_EPA': 0.4168936170212766, 'car_min_ade_err': 0.6188359283156746, 'car_min_fde_err': 0.9741884232379546, 'car_miss_rate_err': 0.1353851265872359, 'pedestrian_min_ade_err': 0.7137598055143394, 'pedestrian_min_fde_err': 1.0442730012917574, 'pedestrian_miss_rate_err': 0.14390649673648942, 'obj_col': 0.006701612793323066, 'obj_box_col': 0.0008248138858309378, 'L2': 0.5903816537724601} +results_dict['sparsedrive_r50_stage2_4gpu_nomap_notrainmask'] = {'img_bbox_NuScenes/car_AP_dist_0.5': 0.3442, 'img_bbox_NuScenes/car_AP_dist_1.0': 0.5914, 'img_bbox_NuScenes/car_AP_dist_2.0': 0.7521, 'img_bbox_NuScenes/car_AP_dist_4.0': 0.8271, 'img_bbox_NuScenes/car_trans_err': 0.3811, 'img_bbox_NuScenes/car_scale_err': 0.1476, 'img_bbox_NuScenes/car_orient_err': 0.072, 'img_bbox_NuScenes/car_vel_err': 0.2106, 'img_bbox_NuScenes/car_attr_err': 0.1925, 'img_bbox_NuScenes/mATE': 0.5684, 'img_bbox_NuScenes/mASE': 0.2723, 'img_bbox_NuScenes/mAOE': 0.5279, 'img_bbox_NuScenes/mAVE': 0.2559, 'img_bbox_NuScenes/mAAE': 0.1894, 'img_bbox_NuScenes/truck_AP_dist_0.5': 0.0966, 'img_bbox_NuScenes/truck_AP_dist_1.0': 0.2541, 'img_bbox_NuScenes/truck_AP_dist_2.0': 0.4585, 'img_bbox_NuScenes/truck_AP_dist_4.0': 0.5452, 'img_bbox_NuScenes/truck_trans_err': 0.5826, 'img_bbox_NuScenes/truck_scale_err': 0.2002, 'img_bbox_NuScenes/truck_orient_err': 0.1221, 'img_bbox_NuScenes/truck_vel_err': 0.1876, 'img_bbox_NuScenes/truck_attr_err': 0.1921, 'img_bbox_NuScenes/construction_vehicle_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/construction_vehicle_AP_dist_1.0': 0.0159, 'img_bbox_NuScenes/construction_vehicle_AP_dist_2.0': 0.1082, 'img_bbox_NuScenes/construction_vehicle_AP_dist_4.0': 0.232, 'img_bbox_NuScenes/construction_vehicle_trans_err': 0.9415, 
'img_bbox_NuScenes/construction_vehicle_scale_err': 0.4728, 'img_bbox_NuScenes/construction_vehicle_orient_err': 1.2952, 'img_bbox_NuScenes/construction_vehicle_vel_err': 0.1314, 'img_bbox_NuScenes/construction_vehicle_attr_err': 0.4037, 'img_bbox_NuScenes/bus_AP_dist_0.5': 0.0683, 'img_bbox_NuScenes/bus_AP_dist_1.0': 0.2488, 'img_bbox_NuScenes/bus_AP_dist_2.0': 0.5065, 'img_bbox_NuScenes/bus_AP_dist_4.0': 0.6973, 'img_bbox_NuScenes/bus_trans_err': 0.6729, 'img_bbox_NuScenes/bus_scale_err': 0.2401, 'img_bbox_NuScenes/bus_orient_err': 0.1349, 'img_bbox_NuScenes/bus_vel_err': 0.4418, 'img_bbox_NuScenes/bus_attr_err': 0.256, 'img_bbox_NuScenes/trailer_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/trailer_AP_dist_1.0': 0.0152, 'img_bbox_NuScenes/trailer_AP_dist_2.0': 0.1115, 'img_bbox_NuScenes/trailer_AP_dist_4.0': 0.2447, 'img_bbox_NuScenes/trailer_trans_err': 1.026, 'img_bbox_NuScenes/trailer_scale_err': 0.2672, 'img_bbox_NuScenes/trailer_orient_err': 0.7522, 'img_bbox_NuScenes/trailer_vel_err': 0.2828, 'img_bbox_NuScenes/trailer_attr_err': 0.0403, 'img_bbox_NuScenes/barrier_AP_dist_0.5': 0.3544, 'img_bbox_NuScenes/barrier_AP_dist_1.0': 0.5654, 'img_bbox_NuScenes/barrier_AP_dist_2.0': 0.6728, 'img_bbox_NuScenes/barrier_AP_dist_4.0': 0.7101, 'img_bbox_NuScenes/barrier_trans_err': 0.3321, 'img_bbox_NuScenes/barrier_scale_err': 0.2847, 'img_bbox_NuScenes/barrier_orient_err': 0.115, 'img_bbox_NuScenes/barrier_vel_err': nan, 'img_bbox_NuScenes/barrier_attr_err': nan, 'img_bbox_NuScenes/motorcycle_AP_dist_0.5': 0.1725, 'img_bbox_NuScenes/motorcycle_AP_dist_1.0': 0.3832, 'img_bbox_NuScenes/motorcycle_AP_dist_2.0': 0.5372, 'img_bbox_NuScenes/motorcycle_AP_dist_4.0': 0.6004, 'img_bbox_NuScenes/motorcycle_trans_err': 0.5145, 'img_bbox_NuScenes/motorcycle_scale_err': 0.2535, 'img_bbox_NuScenes/motorcycle_orient_err': 0.6521, 'img_bbox_NuScenes/motorcycle_vel_err': 0.3426, 'img_bbox_NuScenes/motorcycle_attr_err': 0.2553, 'img_bbox_NuScenes/bicycle_AP_dist_0.5': 0.2015, 
'img_bbox_NuScenes/bicycle_AP_dist_1.0': 0.4043, 'img_bbox_NuScenes/bicycle_AP_dist_2.0': 0.4961, 'img_bbox_NuScenes/bicycle_AP_dist_4.0': 0.5502, 'img_bbox_NuScenes/bicycle_trans_err': 0.4078, 'img_bbox_NuScenes/bicycle_scale_err': 0.254, 'img_bbox_NuScenes/bicycle_orient_err': 1.0339, 'img_bbox_NuScenes/bicycle_vel_err': 0.1277, 'img_bbox_NuScenes/bicycle_attr_err': 0.0091, 'img_bbox_NuScenes/pedestrian_AP_dist_0.5': 0.1789, 'img_bbox_NuScenes/pedestrian_AP_dist_1.0': 0.4351, 'img_bbox_NuScenes/pedestrian_AP_dist_2.0': 0.6511, 'img_bbox_NuScenes/pedestrian_AP_dist_4.0': 0.7558, 'img_bbox_NuScenes/pedestrian_trans_err': 0.5605, 'img_bbox_NuScenes/pedestrian_scale_err': 0.2881, 'img_bbox_NuScenes/pedestrian_orient_err': 0.5734, 'img_bbox_NuScenes/pedestrian_vel_err': 0.323, 'img_bbox_NuScenes/pedestrian_attr_err': 0.1662, 'img_bbox_NuScenes/traffic_cone_AP_dist_0.5': 0.5255, 'img_bbox_NuScenes/traffic_cone_AP_dist_1.0': 0.6562, 'img_bbox_NuScenes/traffic_cone_AP_dist_2.0': 0.7554, 'img_bbox_NuScenes/traffic_cone_AP_dist_4.0': 0.7894, 'img_bbox_NuScenes/traffic_cone_trans_err': 0.2654, 'img_bbox_NuScenes/traffic_cone_scale_err': 0.3147, 'img_bbox_NuScenes/traffic_cone_orient_err': nan, 'img_bbox_NuScenes/traffic_cone_vel_err': nan, 'img_bbox_NuScenes/traffic_cone_attr_err': nan, 'img_bbox_NuScenes/NDS': 0.5250229358032298, 'img_bbox_NuScenes/mAP': 0.41282793609219165, 'img_bbox_NuScenes/amota': 0.37481117124367386, 'img_bbox_NuScenes/amotp': 1.243938826625982, 'img_bbox_NuScenes/recall': 0.5012763587171442, 'img_bbox_NuScenes/motar': 0.6331981367976215, 'img_bbox_NuScenes/gt': 14556.714285714286, 'img_bbox_NuScenes/mota': 0.3471572990666577, 'img_bbox_NuScenes/motp': 0.652721931604462, 'img_bbox_NuScenes/mt': 2425.0, 'img_bbox_NuScenes/ml': 2417.0, 'img_bbox_NuScenes/faf': 49.93009742722906, 'img_bbox_NuScenes/tp': 60614.0, 'img_bbox_NuScenes/fp': 13965.0, 'img_bbox_NuScenes/fn': 40432.0, 'img_bbox_NuScenes/ids': 851.0, 'img_bbox_NuScenes/frag': 554.0, 
'img_bbox_NuScenes/tid': 1.8597660707921904, 'img_bbox_NuScenes/lgd': 2.312623196734337, 'car_EPA': 0.4765860609246604, 'pedestrian_EPA': 0.4042340425531915, 'car_min_ade_err': 0.6226734707457249, 'car_min_fde_err': 0.9728587197855352, 'car_miss_rate_err': 0.1323666881529005, 'pedestrian_min_ade_err': 0.7220416654547696, 'pedestrian_min_fde_err': 1.0578391945840033, 'pedestrian_miss_rate_err': 0.14206930953249303, 'obj_col': 0.006701612793323066, 'obj_box_col': 0.0009984588925565023, 'L2': 0.606941523651282} +results_dict['sparsedrive_r50_stage2_4gpu_bs24_notrainmask'] = {'img_bbox_NuScenes/car_AP_dist_0.5': 0.3508, 'img_bbox_NuScenes/car_AP_dist_1.0': 0.5974, 'img_bbox_NuScenes/car_AP_dist_2.0': 0.7558, 'img_bbox_NuScenes/car_AP_dist_4.0': 0.831, 'img_bbox_NuScenes/car_trans_err': 0.3742, 'img_bbox_NuScenes/car_scale_err': 0.1472, 'img_bbox_NuScenes/car_orient_err': 0.0673, 'img_bbox_NuScenes/car_vel_err': 0.2134, 'img_bbox_NuScenes/car_attr_err': 0.1939, 'img_bbox_NuScenes/mATE': 0.5647, 'img_bbox_NuScenes/mASE': 0.273, 'img_bbox_NuScenes/mAOE': 0.5172, 'img_bbox_NuScenes/mAVE': 0.2548, 'img_bbox_NuScenes/mAAE': 0.1856, 'img_bbox_NuScenes/truck_AP_dist_0.5': 0.0968, 'img_bbox_NuScenes/truck_AP_dist_1.0': 0.2853, 'img_bbox_NuScenes/truck_AP_dist_2.0': 0.4715, 'img_bbox_NuScenes/truck_AP_dist_4.0': 0.5704, 'img_bbox_NuScenes/truck_trans_err': 0.5657, 'img_bbox_NuScenes/truck_scale_err': 0.2025, 'img_bbox_NuScenes/truck_orient_err': 0.1376, 'img_bbox_NuScenes/truck_vel_err': 0.1827, 'img_bbox_NuScenes/truck_attr_err': 0.1764, 'img_bbox_NuScenes/construction_vehicle_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/construction_vehicle_AP_dist_1.0': 0.0164, 'img_bbox_NuScenes/construction_vehicle_AP_dist_2.0': 0.1006, 'img_bbox_NuScenes/construction_vehicle_AP_dist_4.0': 0.2256, 'img_bbox_NuScenes/construction_vehicle_trans_err': 0.9397, 'img_bbox_NuScenes/construction_vehicle_scale_err': 0.4907, 'img_bbox_NuScenes/construction_vehicle_orient_err': 1.2775, 
'img_bbox_NuScenes/construction_vehicle_vel_err': 0.1276, 'img_bbox_NuScenes/construction_vehicle_attr_err': 0.3784, 'img_bbox_NuScenes/bus_AP_dist_0.5': 0.0578, 'img_bbox_NuScenes/bus_AP_dist_1.0': 0.2705, 'img_bbox_NuScenes/bus_AP_dist_2.0': 0.506, 'img_bbox_NuScenes/bus_AP_dist_4.0': 0.7047, 'img_bbox_NuScenes/bus_trans_err': 0.6669, 'img_bbox_NuScenes/bus_scale_err': 0.2114, 'img_bbox_NuScenes/bus_orient_err': 0.142, 'img_bbox_NuScenes/bus_vel_err': 0.4896, 'img_bbox_NuScenes/bus_attr_err': 0.264, 'img_bbox_NuScenes/trailer_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/trailer_AP_dist_1.0': 0.0188, 'img_bbox_NuScenes/trailer_AP_dist_2.0': 0.1256, 'img_bbox_NuScenes/trailer_AP_dist_4.0': 0.2585, 'img_bbox_NuScenes/trailer_trans_err': 1.0126, 'img_bbox_NuScenes/trailer_scale_err': 0.272, 'img_bbox_NuScenes/trailer_orient_err': 0.6611, 'img_bbox_NuScenes/trailer_vel_err': 0.1953, 'img_bbox_NuScenes/trailer_attr_err': 0.0356, 'img_bbox_NuScenes/barrier_AP_dist_0.5': 0.3398, 'img_bbox_NuScenes/barrier_AP_dist_1.0': 0.551, 'img_bbox_NuScenes/barrier_AP_dist_2.0': 0.6489, 'img_bbox_NuScenes/barrier_AP_dist_4.0': 0.6994, 'img_bbox_NuScenes/barrier_trans_err': 0.3263, 'img_bbox_NuScenes/barrier_scale_err': 0.2811, 'img_bbox_NuScenes/barrier_orient_err': 0.1107, 'img_bbox_NuScenes/barrier_vel_err': nan, 'img_bbox_NuScenes/barrier_attr_err': nan, 'img_bbox_NuScenes/motorcycle_AP_dist_0.5': 0.1709, 'img_bbox_NuScenes/motorcycle_AP_dist_1.0': 0.3727, 'img_bbox_NuScenes/motorcycle_AP_dist_2.0': 0.5369, 'img_bbox_NuScenes/motorcycle_AP_dist_4.0': 0.592, 'img_bbox_NuScenes/motorcycle_trans_err': 0.5122, 'img_bbox_NuScenes/motorcycle_scale_err': 0.2601, 'img_bbox_NuScenes/motorcycle_orient_err': 0.732, 'img_bbox_NuScenes/motorcycle_vel_err': 0.3777, 'img_bbox_NuScenes/motorcycle_attr_err': 0.2651, 'img_bbox_NuScenes/bicycle_AP_dist_0.5': 0.2191, 'img_bbox_NuScenes/bicycle_AP_dist_1.0': 0.4086, 'img_bbox_NuScenes/bicycle_AP_dist_2.0': 0.4922, 'img_bbox_NuScenes/bicycle_AP_dist_4.0': 
0.5613, 'img_bbox_NuScenes/bicycle_trans_err': 0.4064, 'img_bbox_NuScenes/bicycle_scale_err': 0.2524, 'img_bbox_NuScenes/bicycle_orient_err': 0.9753, 'img_bbox_NuScenes/bicycle_vel_err': 0.1296, 'img_bbox_NuScenes/bicycle_attr_err': 0.0068, 'img_bbox_NuScenes/pedestrian_AP_dist_0.5': 0.1718, 'img_bbox_NuScenes/pedestrian_AP_dist_1.0': 0.4335, 'img_bbox_NuScenes/pedestrian_AP_dist_2.0': 0.6497, 'img_bbox_NuScenes/pedestrian_AP_dist_4.0': 0.7572, 'img_bbox_NuScenes/pedestrian_trans_err': 0.5675, 'img_bbox_NuScenes/pedestrian_scale_err': 0.2907, 'img_bbox_NuScenes/pedestrian_orient_err': 0.5514, 'img_bbox_NuScenes/pedestrian_vel_err': 0.3226, 'img_bbox_NuScenes/pedestrian_attr_err': 0.1644, 'img_bbox_NuScenes/traffic_cone_AP_dist_0.5': 0.4908, 'img_bbox_NuScenes/traffic_cone_AP_dist_1.0': 0.6467, 'img_bbox_NuScenes/traffic_cone_AP_dist_2.0': 0.7377, 'img_bbox_NuScenes/traffic_cone_AP_dist_4.0': 0.7785, 'img_bbox_NuScenes/traffic_cone_trans_err': 0.2753, 'img_bbox_NuScenes/traffic_cone_scale_err': 0.3223, 'img_bbox_NuScenes/traffic_cone_orient_err': nan, 'img_bbox_NuScenes/traffic_cone_vel_err': nan, 'img_bbox_NuScenes/traffic_cone_attr_err': nan, 'img_bbox_NuScenes/NDS': 0.5267472552418652, 'img_bbox_NuScenes/mAP': 0.41255745286416134, 'img_bbox_NuScenes/amota': 0.3751674076861607, 'img_bbox_NuScenes/amotp': 1.2309562814926793, 'img_bbox_NuScenes/recall': 0.529213083568733, 'img_bbox_NuScenes/motar': 0.6298597506597094, 'img_bbox_NuScenes/gt': 14556.714285714286, 'img_bbox_NuScenes/mota': 0.3433148506844225, 'img_bbox_NuScenes/motp': 0.652450085163147, 'img_bbox_NuScenes/mt': 2494.0, 'img_bbox_NuScenes/ml': 2449.0, 'img_bbox_NuScenes/faf': 76.95633001721761, 'img_bbox_NuScenes/tp': 61285.0, 'img_bbox_NuScenes/fp': 16021.0, 'img_bbox_NuScenes/fn': 39960.0, 'img_bbox_NuScenes/ids': 652.0, 'img_bbox_NuScenes/frag': 577.0, 'img_bbox_NuScenes/tid': 1.6894107547316926, 'img_bbox_NuScenes/lgd': 2.249098392132382, 'ped_crossing': 0.5106424816863883, 'divider': 
0.566774188397169, 'boundary': 0.5805887420239092, 'mAP_normal': 0.5526684707024888, 'car_EPA': 0.48456132584456185, 'pedestrian_EPA': 0.40117021276595743, 'car_min_ade_err': 0.6271592403240047, 'car_min_fde_err': 0.9811241473554299, 'car_miss_rate_err': 0.13298819295968337, 'pedestrian_min_ade_err': 0.7224249630153603, 'pedestrian_min_fde_err': 1.0487785643782717, 'pedestrian_miss_rate_err': 0.14350357990815366, 'obj_col': 0.006701612793323066, 'obj_box_col': 0.0013348961696869489, 'L2': 0.6096154629356332} +results_dict['sparsedrive_r50_stage2_4gpu_bs24_novalmask'] = {'img_bbox_NuScenes/car_AP_dist_0.5': 0.3538, 'img_bbox_NuScenes/car_AP_dist_1.0': 0.6001, 'img_bbox_NuScenes/car_AP_dist_2.0': 0.758, 'img_bbox_NuScenes/car_AP_dist_4.0': 0.8323, 'img_bbox_NuScenes/car_trans_err': 0.3792, 'img_bbox_NuScenes/car_scale_err': 0.1473, 'img_bbox_NuScenes/car_orient_err': 0.069, 'img_bbox_NuScenes/car_vel_err': 0.2209, 'img_bbox_NuScenes/car_attr_err': 0.1944, 'img_bbox_NuScenes/mATE': 0.5547, 'img_bbox_NuScenes/mASE': 0.2744, 'img_bbox_NuScenes/mAOE': 0.5532, 'img_bbox_NuScenes/mAVE': 0.2691, 'img_bbox_NuScenes/mAAE': 0.1896, 'img_bbox_NuScenes/truck_AP_dist_0.5': 0.0902, 'img_bbox_NuScenes/truck_AP_dist_1.0': 0.2458, 'img_bbox_NuScenes/truck_AP_dist_2.0': 0.4512, 'img_bbox_NuScenes/truck_AP_dist_4.0': 0.5459, 'img_bbox_NuScenes/truck_trans_err': 0.5886, 'img_bbox_NuScenes/truck_scale_err': 0.2, 'img_bbox_NuScenes/truck_orient_err': 0.1141, 'img_bbox_NuScenes/truck_vel_err': 0.2054, 'img_bbox_NuScenes/truck_attr_err': 0.1899, 'img_bbox_NuScenes/construction_vehicle_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/construction_vehicle_AP_dist_1.0': 0.0195, 'img_bbox_NuScenes/construction_vehicle_AP_dist_2.0': 0.0852, 'img_bbox_NuScenes/construction_vehicle_AP_dist_4.0': 0.2121, 'img_bbox_NuScenes/construction_vehicle_trans_err': 0.8767, 'img_bbox_NuScenes/construction_vehicle_scale_err': 0.4977, 'img_bbox_NuScenes/construction_vehicle_orient_err': 1.5706, 
'img_bbox_NuScenes/construction_vehicle_vel_err': 0.1386, 'img_bbox_NuScenes/construction_vehicle_attr_err': 0.4074, 'img_bbox_NuScenes/bus_AP_dist_0.5': 0.036, 'img_bbox_NuScenes/bus_AP_dist_1.0': 0.257, 'img_bbox_NuScenes/bus_AP_dist_2.0': 0.4972, 'img_bbox_NuScenes/bus_AP_dist_4.0': 0.6937, 'img_bbox_NuScenes/bus_trans_err': 0.6888, 'img_bbox_NuScenes/bus_scale_err': 0.2271, 'img_bbox_NuScenes/bus_orient_err': 0.134, 'img_bbox_NuScenes/bus_vel_err': 0.4766, 'img_bbox_NuScenes/bus_attr_err': 0.2613, 'img_bbox_NuScenes/trailer_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/trailer_AP_dist_1.0': 0.0222, 'img_bbox_NuScenes/trailer_AP_dist_2.0': 0.1196, 'img_bbox_NuScenes/trailer_AP_dist_4.0': 0.2607, 'img_bbox_NuScenes/trailer_trans_err': 0.9515, 'img_bbox_NuScenes/trailer_scale_err': 0.2659, 'img_bbox_NuScenes/trailer_orient_err': 0.7065, 'img_bbox_NuScenes/trailer_vel_err': 0.264, 'img_bbox_NuScenes/trailer_attr_err': 0.0367, 'img_bbox_NuScenes/barrier_AP_dist_0.5': 0.3474, 'img_bbox_NuScenes/barrier_AP_dist_1.0': 0.561, 'img_bbox_NuScenes/barrier_AP_dist_2.0': 0.6808, 'img_bbox_NuScenes/barrier_AP_dist_4.0': 0.7168, 'img_bbox_NuScenes/barrier_trans_err': 0.3446, 'img_bbox_NuScenes/barrier_scale_err': 0.2799, 'img_bbox_NuScenes/barrier_orient_err': 0.1079, 'img_bbox_NuScenes/barrier_vel_err': nan, 'img_bbox_NuScenes/barrier_attr_err': nan, 'img_bbox_NuScenes/motorcycle_AP_dist_0.5': 0.1776, 'img_bbox_NuScenes/motorcycle_AP_dist_1.0': 0.3563, 'img_bbox_NuScenes/motorcycle_AP_dist_2.0': 0.5189, 'img_bbox_NuScenes/motorcycle_AP_dist_4.0': 0.5853, 'img_bbox_NuScenes/motorcycle_trans_err': 0.4979, 'img_bbox_NuScenes/motorcycle_scale_err': 0.2571, 'img_bbox_NuScenes/motorcycle_orient_err': 0.7239, 'img_bbox_NuScenes/motorcycle_vel_err': 0.3656, 'img_bbox_NuScenes/motorcycle_attr_err': 0.2564, 'img_bbox_NuScenes/bicycle_AP_dist_0.5': 0.2349, 'img_bbox_NuScenes/bicycle_AP_dist_1.0': 0.4202, 'img_bbox_NuScenes/bicycle_AP_dist_2.0': 0.4961, 'img_bbox_NuScenes/bicycle_AP_dist_4.0': 
0.56, 'img_bbox_NuScenes/bicycle_trans_err': 0.396, 'img_bbox_NuScenes/bicycle_scale_err': 0.2673, 'img_bbox_NuScenes/bicycle_orient_err': 0.9788, 'img_bbox_NuScenes/bicycle_vel_err': 0.1505, 'img_bbox_NuScenes/bicycle_attr_err': 0.0061, 'img_bbox_NuScenes/pedestrian_AP_dist_0.5': 0.1824, 'img_bbox_NuScenes/pedestrian_AP_dist_1.0': 0.4457, 'img_bbox_NuScenes/pedestrian_AP_dist_2.0': 0.6639, 'img_bbox_NuScenes/pedestrian_AP_dist_4.0': 0.7674, 'img_bbox_NuScenes/pedestrian_trans_err': 0.5609, 'img_bbox_NuScenes/pedestrian_scale_err': 0.2893, 'img_bbox_NuScenes/pedestrian_orient_err': 0.5742, 'img_bbox_NuScenes/pedestrian_vel_err': 0.3311, 'img_bbox_NuScenes/pedestrian_attr_err': 0.1647, 'img_bbox_NuScenes/traffic_cone_AP_dist_0.5': 0.5225, 'img_bbox_NuScenes/traffic_cone_AP_dist_1.0': 0.6764, 'img_bbox_NuScenes/traffic_cone_AP_dist_2.0': 0.7529, 'img_bbox_NuScenes/traffic_cone_AP_dist_4.0': 0.789, 'img_bbox_NuScenes/traffic_cone_trans_err': 0.2632, 'img_bbox_NuScenes/traffic_cone_scale_err': 0.3128, 'img_bbox_NuScenes/traffic_cone_orient_err': nan, 'img_bbox_NuScenes/traffic_cone_vel_err': nan, 'img_bbox_NuScenes/traffic_cone_attr_err': nan, 'img_bbox_NuScenes/NDS': 0.5225890076722248, 'img_bbox_NuScenes/mAP': 0.4133932201012164, 'img_bbox_NuScenes/amota': 0.3754069553264074, 'img_bbox_NuScenes/amotp': 1.2561417111667958, 'img_bbox_NuScenes/recall': 0.5258422088207545, 'img_bbox_NuScenes/motar': 0.6338878086993166, 'img_bbox_NuScenes/gt': 14556.714285714286, 'img_bbox_NuScenes/mota': 0.34343638941858706, 'img_bbox_NuScenes/motp': 0.6429838895199319, 'img_bbox_NuScenes/mt': 2495.0, 'img_bbox_NuScenes/ml': 2226.0, 'img_bbox_NuScenes/faf': 75.21906140713715, 'img_bbox_NuScenes/tp': 61706.0, 'img_bbox_NuScenes/fp': 15583.0, 'img_bbox_NuScenes/fn': 39092.0, 'img_bbox_NuScenes/ids': 1099.0, 'img_bbox_NuScenes/frag': 635.0, 'img_bbox_NuScenes/tid': 1.6866828505090832, 'img_bbox_NuScenes/lgd': 2.2873464775229038, 'ped_crossing': 0.48785861624040217, 'divider': 
0.5796305207181903, 'boundary': 0.5912029467581585, 'mAP_normal': 0.552897361238917, 'car_EPA': 0.49216307445425117, 'pedestrian_EPA': 0.41306382978723405, 'car_min_ade_err': 0.6361001780131743, 'car_min_fde_err': 1.0011809885715843, 'car_miss_rate_err': 0.13309976946893648, 'pedestrian_min_ade_err': 0.729691258942758, 'pedestrian_min_fde_err': 1.0665211591121826, 'pedestrian_miss_rate_err': 0.14577341344157752, 'obj_col': 0.006701612793323066, 'obj_box_col': 0.0013348961502843953, 'L2': 0.6358533894850148} + +for key, value in results_dict.items(): + init_kwargs['name'] = key + wandb.init(**init_kwargs) + wandb.log({'val/' + k: v for k, v in value.items()}) + wandb.finish() diff --git a/tools/visualization/bev_render.py b/tools/visualization/bev_render.py new file mode 100644 index 0000000..35db053 --- /dev/null +++ b/tools/visualization/bev_render.py @@ -0,0 +1,408 @@ +import os +import numpy as np +import cv2 + +import matplotlib +import matplotlib.pyplot as plt + +from projects.mmdet3d_plugin.datasets.utils import box3d_to_corners + +CMD_LIST = ['Turn Right', 'Turn Left', 'Go Straight'] +COLOR_VECTORS = ['cornflowerblue', 'royalblue', 'slategrey'] +SCORE_THRESH = 0.3 +MAP_SCORE_THRESH = 0.3 +color_mapping = np.asarray([ + [0, 0, 0], + [255, 179, 0], + [128, 62, 117], + [255, 104, 0], + [166, 189, 215], + [193, 0, 32], + [206, 162, 98], + [129, 112, 102], + [0, 125, 52], + [246, 118, 142], + [0, 83, 138], + [255, 122, 92], + [83, 55, 122], + [255, 142, 0], + [179, 40, 81], + [244, 200, 0], + [127, 24, 13], + [147, 170, 0], + [89, 51, 21], + [241, 58, 19], + [35, 44, 22], + [112, 224, 255], + [70, 184, 160], + [153, 0, 255], + [71, 255, 0], + [255, 0, 163], + [255, 204, 0], + [0, 255, 235], + [255, 0, 235], + [255, 0, 122], + [255, 245, 0], + [10, 190, 212], + [214, 255, 0], + [0, 204, 255], + [20, 0, 255], + [255, 255, 0], + [0, 153, 255], + [0, 255, 204], + [41, 255, 0], + [173, 0, 255], + [0, 245, 255], + [71, 0, 255], + [0, 255, 184], + [0, 92, 255], + [184, 
255, 0], + [255, 214, 0], + [25, 194, 194], + [92, 0, 255], + [220, 220, 220], + [255, 9, 92], + [112, 9, 255], + [8, 255, 214], + [255, 184, 6], + [10, 255, 71], + [255, 41, 10], + [7, 255, 255], + [224, 255, 8], + [102, 8, 255], + [255, 61, 6], + [255, 194, 7], + [0, 255, 20], + [255, 8, 41], + [255, 5, 153], + [6, 51, 255], + [235, 12, 255], + [160, 150, 20], + [0, 163, 255], + [140, 140, 140], + [250, 10, 15], + [20, 255, 0], +]) / 255 + + +class BEVRender: + def __init__( + self, + plot_choices, + out_dir, + xlim = 40, + ylim = 40, + ): + self.plot_choices = plot_choices + self.xlim = xlim + self.ylim = ylim + self.gt_dir = os.path.join(out_dir, "bev_gt") + self.pred_dir = os.path.join(out_dir, "bev_pred") + os.makedirs(self.gt_dir, exist_ok=True) + os.makedirs(self.pred_dir, exist_ok=True) + + def reset_canvas(self): + plt.close() + self.fig, self.axes = plt.subplots(1, 1, figsize=(20, 20)) + self.axes.set_xlim(- self.xlim, self.xlim) + self.axes.set_ylim(- self.ylim, self.ylim) + self.axes.axis('off') + + def render( + self, + data, + result, + index, + ): + self.reset_canvas() + self.draw_detection_gt(data) + self.draw_motion_gt(data) + self.draw_map_gt(data) + self.draw_planning_gt(data) + self._render_sdc_car() + self._render_command(data) + self._render_legend() + save_path_gt = os.path.join(self.gt_dir, str(index).zfill(4) + '.jpg') + self.save_fig(save_path_gt) + + self.reset_canvas() + self.draw_detection_pred(result) + self.draw_track_pred(result) + self.draw_motion_pred(result) + self.draw_map_pred(result) + self.draw_planning_pred(data, result) + self._render_sdc_car() + self._render_command(data) + self._render_legend() + save_path_pred = os.path.join(self.pred_dir, str(index).zfill(4) + '.jpg') + self.save_fig(save_path_pred) + + return save_path_gt, save_path_pred + + def save_fig(self, filename): + plt.subplots_adjust(top=1, bottom=0, right=1, left=0, + hspace=0, wspace=0) + plt.margins(0, 0) + plt.savefig(filename) + + def 
draw_detection_gt(self, data): + if not self.plot_choices['det']: + return + + for i in range(data['gt_labels_3d'].shape[0]): + label = data['gt_labels_3d'][i] + if label == -1: + continue + color = color_mapping[i % len(color_mapping)] + + # draw corners + corners = box3d_to_corners(data['gt_bboxes_3d'])[i, [0, 3, 7, 4, 0]] + x = corners[:, 0] + y = corners[:, 1] + self.axes.plot(x, y, color=color, linewidth=3, linestyle='-') + + # draw line to indicate forward direction + forward_center = np.mean(corners[2:4], axis=0) + center = np.mean(corners[0:4], axis=0) + x = [forward_center[0], center[0]] + y = [forward_center[1], center[1]] + self.axes.plot(x, y, color=color, linewidth=3, linestyle='-') + + def draw_detection_pred(self, result): + if not (self.plot_choices['draw_pred'] and self.plot_choices['det'] and "boxes_3d" in result): + return + + bboxes = result['boxes_3d'] + for i in range(result['labels_3d'].shape[0]): + score = result['scores_3d'][i] + if score < SCORE_THRESH: + continue + color = color_mapping[result['instance_ids'][i] % len(color_mapping)] + + # draw corners + corners = box3d_to_corners(bboxes)[i, [0, 3, 7, 4, 0]] + x = corners[:, 0] + y = corners[:, 1] + self.axes.plot(x, y, color=color, linewidth=3, linestyle='-') + + # draw line to indicate forward direction + forward_center = np.mean(corners[2:4], axis=0) + center = np.mean(corners[0:4], axis=0) + x = [forward_center[0], center[0]] + y = [forward_center[1], center[1]] + self.axes.plot(x, y, color=color, linewidth=3, linestyle='-') + + def draw_track_pred(self, result): + if not (self.plot_choices['draw_pred'] and self.plot_choices['track'] and "anchor_queue" in result): + return + + temp_bboxes = result["anchor_queue"] + period = result["period"] + bboxes = result['boxes_3d'] + for i in range(result['labels_3d'].shape[0]): + score = result['scores_3d'][i] + if score < SCORE_THRESH: + continue + color = color_mapping[result['instance_ids'][i] % len(color_mapping)] + center = bboxes[i, :3] + 
centers = [center] + for j in range(period[i]): + # draw corners + corners = box3d_to_corners(temp_bboxes[:, -1-j])[i, [0, 3, 7, 4, 0]] + x = corners[:, 0] + y = corners[:, 1] + self.axes.plot(x, y, color=color, linewidth=2, linestyle='-') + + # draw line to indicate forward direction + forward_center = np.mean(corners[2:4], axis=0) + center = np.mean(corners[0:4], axis=0) + x = [forward_center[0], center[0]] + y = [forward_center[1], center[1]] + self.axes.plot(x, y, color=color, linewidth=2, linestyle='-') + centers.append(center) + + centers = np.stack(centers) + xs = centers[:, 0] + ys = centers[:, 1] + self.axes.plot(xs, ys, color=color, linewidth=2, linestyle='-') + + def draw_motion_gt(self, data): + if not self.plot_choices['motion']: + return + + for i in range(data['gt_labels_3d'].shape[0]): + label = data['gt_labels_3d'][i] + if label == -1: + continue + color = color_mapping[i % len(color_mapping)] + vehicle_id_list = [0, 1, 2, 3, 4, 6, 7] + if label in vehicle_id_list: + dot_size = 150 + else: + dot_size = 25 + + center = data['gt_bboxes_3d'][i, :2] + masks = data['gt_agent_fut_masks'][i].astype(bool) + if masks[0] == 0: + continue + trajs = data['gt_agent_fut_trajs'][i][masks] + trajs = trajs.cumsum(axis=0) + center + trajs = np.concatenate([center.reshape(1, 2), trajs], axis=0) + + self._render_traj(trajs, traj_score=1.0, + colormap='winter', dot_size=dot_size) + + def draw_motion_pred(self, result, top_k=3): + if not (self.plot_choices['draw_pred'] and self.plot_choices['motion'] and "trajs_3d" in result): + return + + bboxes = result['boxes_3d'] + labels = result['labels_3d'] + for i in range(result['labels_3d'].shape[0]): + score = result['scores_3d'][i] + if score < SCORE_THRESH: + continue + label = labels[i] + vehicle_id_list = [0, 1, 2, 3, 4, 6, 7] + if label in vehicle_id_list: + dot_size = 150 + else: + dot_size = 25 + + traj_score = result['trajs_score'][i].numpy() + traj = result['trajs_3d'][i].numpy() + num_modes = len(traj_score) + 
center = bboxes[i, :2][None, None].repeat(num_modes, 1, 1).numpy() + traj = np.concatenate([center, traj], axis=1) + + sorted_ind = np.argsort(traj_score)[::-1] + sorted_traj = traj[sorted_ind, :, :2] + sorted_score = traj_score[sorted_ind] + norm_score = np.exp(sorted_score[0]) + + for j in range(top_k - 1, -1, -1): + viz_traj = sorted_traj[j] + traj_score = np.exp(sorted_score[j])/norm_score + self._render_traj(viz_traj, traj_score=traj_score, + colormap='winter', dot_size=dot_size) + + def draw_map_gt(self, data): + if not self.plot_choices['map']: + return + vectors = data['map_infos'] + for label, vector_list in vectors.items(): + color = COLOR_VECTORS[label] + for vector in vector_list: + pts = vector[:, :2] + x = np.array([pt[0] for pt in pts]) + y = np.array([pt[1] for pt in pts]) + self.axes.plot(x, y, color=color, linewidth=3, marker='o', linestyle='-', markersize=7) + + def draw_map_pred(self, result): + if not (self.plot_choices['draw_pred'] and self.plot_choices['map'] and "vectors" in result): + return + + for i in range(result['scores'].shape[0]): + score = result['scores'][i] + if score < MAP_SCORE_THRESH: + continue + color = COLOR_VECTORS[result['labels'][i]] + pts = result['vectors'][i] + x = pts[:, 0] + y = pts[:, 1] + plt.plot(x, y, color=color, linewidth=3, marker='o', linestyle='-', markersize=7) + + def draw_planning_gt(self, data): + if not self.plot_choices['planning']: + return + + # draw planning gt + masks = data['gt_ego_fut_masks'].astype(bool) + if masks[0] != 0: + plan_traj = data['gt_ego_fut_trajs'][masks] + cmd = data['gt_ego_fut_cmd'] + plan_traj[abs(plan_traj) < 0.01] = 0.0 + plan_traj = plan_traj.cumsum(axis=0) + plan_traj = np.concatenate((np.zeros((1, plan_traj.shape[1])), plan_traj), axis=0) + self._render_traj(plan_traj, traj_score=1.0, + colormap='autumn', dot_size=50) + + def draw_planning_pred(self, data, result, top_k=3): + if not (self.plot_choices['draw_pred'] and self.plot_choices['planning'] and "planning" in 
result): + return + + if self.plot_choices['track'] and "ego_anchor_queue" in result: + ego_temp_bboxes = result["ego_anchor_queue"] + ego_period = result["ego_period"] + for j in range(ego_period[0]): + # draw corners + corners = box3d_to_corners(ego_temp_bboxes[:, -1-j])[0, [0, 3, 7, 4, 0]] + x = corners[:, 0] + y = corners[:, 1] + self.axes.plot(x, y, color='mediumseagreen', linewidth=2, linestyle='-') + + # draw line to indicate forward direction + forward_center = np.mean(corners[2:4], axis=0) + center = np.mean(corners[0:4], axis=0) + x = [forward_center[0], center[0]] + y = [forward_center[1], center[1]] + self.axes.plot(x, y, color='mediumseagreen', linewidth=2, linestyle='-') + # import ipdb; ipdb.set_trace() + plan_trajs = result['planning'].cpu().numpy() + num_cmd = len(CMD_LIST) + num_mode = plan_trajs.shape[1] + plan_trajs = np.concatenate((np.zeros((num_cmd, num_mode, 1, 2)), plan_trajs), axis=2) + plan_score = result['planning_score'].cpu().numpy() + + cmd = data['gt_ego_fut_cmd'].argmax() + plan_trajs = plan_trajs[cmd] + plan_score = plan_score[cmd] + + sorted_ind = np.argsort(plan_score)[::-1] + sorted_traj = plan_trajs[sorted_ind, :, :2] + sorted_score = plan_score[sorted_ind] + norm_score = np.exp(sorted_score[0]) + + for j in range(top_k - 1, -1, -1): + viz_traj = sorted_traj[j] + traj_score = np.exp(sorted_score[j]) / norm_score + self._render_traj(viz_traj, traj_score=traj_score, + colormap='autumn', dot_size=50) + + def _render_traj( + self, + future_traj, + traj_score=1, + colormap='winter', + points_per_step=20, + dot_size=25 + ): + total_steps = (len(future_traj) - 1) * points_per_step + 1 + dot_colors = matplotlib.colormaps[colormap]( + np.linspace(0, 1, total_steps))[:, :3] + dot_colors = dot_colors * traj_score + \ + (1 - traj_score) * np.ones_like(dot_colors) + total_xy = np.zeros((total_steps, 2)) + for i in range(total_steps - 1): + unit_vec = future_traj[i // points_per_step + + 1] - future_traj[i // points_per_step] + total_xy[i] = 
(i / points_per_step - i // points_per_step) * \ + unit_vec + future_traj[i // points_per_step] + total_xy[-1] = future_traj[-1] + self.axes.scatter( + total_xy[:, 0], total_xy[:, 1], c=dot_colors, s=dot_size) + + def _render_sdc_car(self): + sdc_car_png = cv2.imread('resources/sdc_car.png') + sdc_car_png = cv2.cvtColor(sdc_car_png, cv2.COLOR_BGR2RGB) + im = self.axes.imshow(sdc_car_png, extent=(-1, 1, -2, 2)) + im.set_zorder(2) + + def _render_legend(self): + legend = cv2.imread('resources/legend.png') + legend = cv2.cvtColor(legend, cv2.COLOR_BGR2RGB) + self.axes.imshow(legend, extent=(15, 40, -40, -30)) + + def _render_command(self, data): + cmd = data['gt_ego_fut_cmd'].argmax() + self.axes.text(-38, -38, CMD_LIST[cmd], fontsize=60) \ No newline at end of file diff --git a/tools/visualization/cam_render.py b/tools/visualization/cam_render.py new file mode 100644 index 0000000..be05651 --- /dev/null +++ b/tools/visualization/cam_render.py @@ -0,0 +1,271 @@ +import os +import numpy as np +import cv2 +from PIL import Image + +import matplotlib +import matplotlib.pyplot as plt +from pyquaternion import Quaternion +from nuscenes.utils.data_classes import Box as NuScenesBox +from nuscenes.utils.geometry_utils import view_points, box_in_image, BoxVisibility, transform_matrix + +from tools.visualization.bev_render import ( + color_mapping, + SCORE_THRESH, + MAP_SCORE_THRESH, + CMD_LIST +) + + +CAM_NAMES_NUSC = [ + 'CAM_FRONT_LEFT', + 'CAM_FRONT', + 'CAM_FRONT_RIGHT', + 'CAM_BACK_RIGHT', + 'CAM_BACK', + 'CAM_BACK_LEFT', +] +CAM_NAMES_NUSC_converter = [ + 'CAM_FRONT', + 'CAM_FRONT_RIGHT', + 'CAM_FRONT_LEFT', + 'CAM_BACK', + 'CAM_BACK_LEFT', + 'CAM_BACK_RIGHT', +] + +class CamRender: + def __init__( + self, + plot_choices, + out_dir, + ): + self.plot_choices = plot_choices + self.pred_dir = os.path.join(out_dir, "cam_pred") + os.makedirs(self.pred_dir, exist_ok=True) + + def reset_canvas(self): + plt.close() + plt.gca().set_axis_off() + plt.axis('off') + self.fig, 
self.axes = plt.subplots(2, 3, figsize=(160 /3 , 20)) + plt.tight_layout() + + def render( + self, + data, + result, + index, + ): + self.reset_canvas() + self.render_image_data(data, index) + self.draw_detection_pred(data, result) + self.draw_motion_pred(data, result) + self.draw_planning_pred(data, result) + save_path = os.path.join(self.pred_dir, str(index).zfill(4) + '.jpg') + self.save_fig(save_path) + return save_path + + def load_image(self, data_path, cam): + """Update the axis of the plot with the provided image.""" + image = np.array(Image.open(data_path)) + font = cv2.FONT_HERSHEY_SIMPLEX + org = (50, 60) + fontScale = 2 + color = (0, 0, 0) + thickness = 4 + return cv2.putText(image, cam, org, font, fontScale, color, thickness, cv2.LINE_AA) + + def update_image(self, image, index, cam): + """Render image data for each camera.""" + ax = self.get_axis(index) + ax.imshow(image) + plt.axis('off') + ax.axis('off') + ax.grid(False) + + def get_axis(self, index): + """Retrieve the corresponding axis based on the index.""" + return self.axes[index//3, index % 3] + + def save_fig(self, filename): + plt.subplots_adjust(top=1, bottom=0, right=1, left=0, + hspace=0, wspace=0) + plt.margins(0, 0) + plt.savefig(filename) + + def render_image_data(self, data, index): + """Load and annotate image based on the provided path.""" + for i, cam in enumerate(CAM_NAMES_NUSC): + idx = CAM_NAMES_NUSC_converter.index(cam) + img_path = data['img_filename'][idx] + image = self.load_image(img_path, cam) + self.update_image(image, i, cam) + + def draw_detection_pred(self, data, result): + if not (self.plot_choices['draw_pred'] and self.plot_choices['det'] and "boxes_3d" in result): + return + + bboxes = result['boxes_3d'].numpy() + for j, cam in enumerate(CAM_NAMES_NUSC): + idx = CAM_NAMES_NUSC_converter.index(cam) + cam_intrinsic = data['cam_intrinsic'][idx] + lidar2cam = data['lidar2cam'] + extrinsic = lidar2cam[idx] + trans = extrinsic[3, :3] + rot = 
Quaternion(matrix=extrinsic[:3, :3]).inverse + imsize = (1600, 900) + + for i in range(result['labels_3d'].shape[0]): + score = result['scores_3d'][i] + if score < SCORE_THRESH: + continue + color = color_mapping[result['instance_ids'][i] % len(color_mapping)] + + center = bboxes[i, 0 : 3] + box_dims = bboxes[i, 3 : 6] + nusc_dims = box_dims[..., [1, 0, 2]] + quat = Quaternion(axis=[0, 0, 1], radians=bboxes[i, 6]) + box = NuScenesBox( + center, + nusc_dims, + quat + ) + box.rotate(rot) + box.translate(trans) + if box_in_image(box, cam_intrinsic, imsize): + box.render( + self.axes[j // 3, j % 3], + view=cam_intrinsic, + normalize=True, + colors=(color, color, color), + linewidth=4, + ) + + self.axes[j//3, j % 3].set_xlim(0, imsize[0]) + self.axes[j//3, j % 3].set_ylim(imsize[1], 0) + + def draw_motion_pred(self, data, result, points_per_step=10): + if not (self.plot_choices['draw_pred'] and self.plot_choices['motion'] and "trajs_3d" in result): + return + + bboxes = result['boxes_3d'].numpy() + for j, cam in enumerate(CAM_NAMES_NUSC): + idx = CAM_NAMES_NUSC_converter.index(cam) + cam_intrinsic = data['cam_intrinsic'][idx] + lidar2cam = data['lidar2cam'] + extrinsic = lidar2cam[idx] + trans = extrinsic[3, :3] + rot = Quaternion(matrix=extrinsic[:3, :3]).inverse + imsize = (1600, 900) + + for i in range(result['labels_3d'].shape[0]): + score = result['scores_3d'][i] + if score < SCORE_THRESH: + continue + color = color_mapping[result['instance_ids'][i] % len(color_mapping)] + + traj_score = result['trajs_score'][i].numpy() + traj = result['trajs_3d'][i].numpy() + + mode_idx = traj_score.argmax() + traj = traj[mode_idx] + origin = bboxes[i, :2][None] + traj = np.concatenate([origin, traj], axis=0) + traj_expand = np.ones((traj.shape[0], 1)) + traj_expand[:] = bboxes[i, 2] - bboxes[i, 5] / 2 + traj = np.concatenate([traj, traj_expand], axis=1) + + center = bboxes[i, 0 : 3] + box_dims = bboxes[i, 3 : 6] + nusc_dims = box_dims[..., [1, 0, 2]] + quat = Quaternion(axis=[0, 
0, 1], radians=bboxes[i, 6]) + box = NuScenesBox( + center, + nusc_dims, + quat + ) + box.rotate(rot) + box.translate(trans) + if not box_in_image(box, cam_intrinsic, imsize): + continue + traj_points = traj @ extrinsic[:3, :3] + trans + self._render_traj(traj_points, cam_intrinsic, j, color=color, s=15) + + + def draw_planning_pred(self, data, result): + if not (self.plot_choices['draw_pred'] and self.plot_choices['planning'] and "planning" in result): + return + # for j, cam in enumerate(CAM_NAMES_NUSC[1]): + # idx = CAM_NAMES_NUSC_converter.index(cam) + # cam_intrinsic = data['cam_intrinsic'][idx] + # lidar2cam = data['lidar2cam'] + # extrinsic = lidar2cam[idx] + # trans = extrinsic[3, :3] + # rot = Quaternion(matrix=extrinsic[:3, :3]).inverse + # imsize = (1600, 900) + + # plan_trajs = result['planning'][0].cpu().numpy() + # plan_trajs = plan_trajs.reshape(3, -1, 6, 2) + # num_cmd = len(CMD_LIST) + # num_mode = plan_trajs.shape[1] + # plan_trajs = np.concatenate((np.zeros((num_cmd, num_mode, 1, 2)), plan_trajs), axis=2) + # plan_trajs = plan_trajs.cumsum(axis=-2) + # plan_score = result['planning_score'][0].cpu().numpy() + # plan_score = plan_score.reshape(3, -1) + + # cmd = data['gt_ego_fut_cmd'].argmax() + # plan_trajs = plan_trajs[cmd] + # plan_score = plan_score[cmd] + + # mode_idx = plan_score.argmax() + # plan_traj = plan_trajs[mode_idx] + # traj_expand = np.ones((plan_traj.shape[0], 1)) * -2 + # # traj_expand[:] = bboxes[i, 2] - bboxes[i, 5] / 2 + # plan_traj = np.concatenate([plan_traj, traj_expand], axis=1) + + # traj_points = plan_traj @ extrinsic[:3, :3] + trans + # self._render_traj(traj_points, cam_intrinsic, j) + + idx = 0 ## front camera + cam_intrinsic = data['cam_intrinsic'][idx] + lidar2cam = data['lidar2cam'] + extrinsic = lidar2cam[idx] + trans = extrinsic[3, :3] + rot = Quaternion(matrix=extrinsic[:3, :3]).inverse + # plan_trajs = result['planning'][0].cpu().numpy() + # plan_trajs = plan_trajs.reshape(3, -1, 6, 2) + # num_cmd = 
len(CMD_LIST) + # num_mode = plan_trajs.shape[1] + # plan_trajs = np.concatenate((np.zeros((num_cmd, num_mode, 1, 2)), plan_trajs), axis=2) + # plan_trajs = plan_trajs.cumsum(axis=-2) + # plan_score = result['planning_score'][0].cpu().numpy() + # plan_score = plan_score.reshape(3, -1) + + # cmd = data['gt_ego_fut_cmd'].argmax() + # plan_trajs = plan_trajs[cmd] + # plan_score = plan_score[cmd] + + # mode_idx = plan_score.argmax() + # plan_traj = plan_trajs[mode_idx] + plan_traj = result["final_planning"] + plan_traj = np.concatenate((np.zeros((1, 2)), plan_traj), axis=0) + traj_expand = np.ones((plan_traj.shape[0], 1)) * -1.8 + plan_traj = np.concatenate([plan_traj, traj_expand], axis=1) + + traj_points = plan_traj @ extrinsic[:3, :3] + trans + self._render_traj(traj_points, cam_intrinsic, j=1) + + def _render_traj(self, traj_points, cam_intrinsic, j, color=(1, 0.5, 0), s=150, points_per_step=10): + total_steps = (len(traj_points)-1) * points_per_step + 1 + total_xy = np.zeros((total_steps, 3)) + for k in range(total_steps-1): + unit_vec = traj_points[k//points_per_step + + 1] - traj_points[k//points_per_step] + total_xy[k] = (k/points_per_step - k//points_per_step) * \ + unit_vec + traj_points[k//points_per_step] + in_range_mask = total_xy[:, 2] > 0.1 + traj_points = view_points( + total_xy.T, cam_intrinsic, normalize=True)[:2, :] + traj_points = traj_points[:2, in_range_mask] + self.axes[j // 3, j % 3].scatter(traj_points[0], traj_points[1], color=color, s=s) \ No newline at end of file diff --git a/tools/visualization/visualize.py b/tools/visualization/visualize.py new file mode 100644 index 0000000..591908a --- /dev/null +++ b/tools/visualization/visualize.py @@ -0,0 +1,110 @@ +import os +import glob +import argparse +from tqdm import tqdm + +import cv2 +import numpy as np +from PIL import Image + +import mmcv +from mmcv import Config +from mmdet.datasets import build_dataset + +from tools.visualization.bev_render import BEVRender +from 
tools.visualization.cam_render import CamRender + +plot_choices = dict( + draw_pred = True, # True: draw gt and pred; False: only draw gt + det = True, + track = True, # True: draw history tracked boxes + motion = True, + map = True, + planning = True, +) +START = 0 +END = 81 +INTERVAL = 1 + + +class Visualizer: + def __init__( + self, + args, + plot_choices, + ): + self.out_dir = args.out_dir + self.combine_dir = os.path.join(self.out_dir, 'combine') + os.makedirs(self.combine_dir, exist_ok=True) + + cfg = Config.fromfile(args.config) + self.dataset = build_dataset(cfg.data.val) + self.results = mmcv.load(args.result_path) + self.bev_render = BEVRender(plot_choices, self.out_dir) + self.cam_render = CamRender(plot_choices, self.out_dir) + + def add_vis(self, index): + data = self.dataset.get_data_info(index) + result = self.results[index]['img_bbox'] + + bev_gt_path, bev_pred_path = self.bev_render.render(data, result, index) + cam_pred_path = self.cam_render.render(data, result, index) + self.combine(bev_gt_path, bev_pred_path, cam_pred_path, index) + + def combine(self, bev_gt_path, bev_pred_path, cam_pred_path, index): + bev_gt = cv2.imread(bev_gt_path) + bev_image = cv2.imread(bev_pred_path) + cam_image = cv2.imread(cam_pred_path) + merge_image = cv2.hconcat([cam_image, bev_image, bev_gt]) + save_path = os.path.join(self.combine_dir, str(index).zfill(4) + '.jpg') + cv2.imwrite(save_path, merge_image) + + def image2video(self, fps=12, downsample=4): + imgs_path = glob.glob(os.path.join(self.combine_dir, '*.jpg')) + imgs_path = sorted(imgs_path) + img_array = [] + for img_path in tqdm(imgs_path): + img = cv2.imread(img_path) + height, width, channel = img.shape + img = cv2.resize(img, (width//downsample, height // + downsample), interpolation=cv2.INTER_AREA) + height, width, channel = img.shape + size = (width, height) + img_array.append(img) + out_path = os.path.join(self.out_dir, 'video.mp4') + out = cv2.VideoWriter( + out_path, 
cv2.VideoWriter_fourcc(*'mp4v'), fps, size) + for i in range(len(img_array)): + out.write(img_array[i]) + out.release() + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Visualize groundtruth and results') + parser.add_argument('config', help='config file path') + parser.add_argument('--result-path', + default=None, + help='prediction result to visualize' + 'If submission file is not provided, only gt will be visualized') + parser.add_argument( + '--out-dir', + default='vis', + help='directory where visualize results will be saved') + args = parser.parse_args() + + return args + +def main(): + args = parse_args() + visualizer = Visualizer(args, plot_choices) + + for idx in tqdm(range(START, END, INTERVAL)): + if idx > len(visualizer.results): + break + visualizer.add_vis(idx) + + visualizer.image2video() + +if __name__ == '__main__': + main() \ No newline at end of file