diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2e7c543
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,24 @@
+*.pyc
+*.npy
+*.pth
+*.whl
+*.swp
+*.sif
+
+wandb/
+data/
+ckpt/
+work_dirs*/
+dist_test/
+vis/
+val/
+lib/
+logs/
+
+*.egg-info
+build/
+__pycache__/
+*.so
+
+job_scripts/
+temp_ops/
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
index 261eeb9..5c09177 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,201 +1,21 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright [yyyy] [name of copyright owner]
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
+MIT License
+
+Copyright (c) 2024 swc-17
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
deleted file mode 100644
index b1f1ebe..0000000
--- a/README.md
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
ForeSight
-[ICCV2025] ForeSight: Multi-View Streaming Joint Object Detection and Trajectory Forecasting
-
-
-[](https://arxiv.org/abs/2508.07089)
-
-Code coming soon.
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000..1f3eb07
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,22 @@
+FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-devel
+ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6+PTX" \
+ TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
+ CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
+ FORCE_CUDA="1"
+
+# Install the MMDetection3D required packages
+RUN apt-get update \
+ && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+
+# Install StreamPETR required packages
+WORKDIR /workspace/ForeSight/
+COPY requirement.txt .
+COPY projects/ projects/
+RUN pip install mmcv-full==1.7.1 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.13.0/index.html
+RUN pip install --no-cache-dir -r requirement.txt
+RUN cd projects/mmdet3d_plugin/ops && python setup.py develop
+RUN git config --global --add safe.directory '*'
+
+WORKDIR /workspace/ForeSight/
\ No newline at end of file
diff --git a/docs/quick_start.md b/docs/quick_start.md
new file mode 100644
index 0000000..0ad0e8f
--- /dev/null
+++ b/docs/quick_start.md
@@ -0,0 +1,64 @@
+# Quick Start
+
+### Set up a new virtual environment
+```bash
+conda create -n sparsedrive python=3.8 -y
+conda activate sparsedrive
+```
+
+### Install dependency packages
+```bash
+sparsedrive_path="path/to/sparsedrive"
+cd ${sparsedrive_path}
+pip3 install --upgrade pip
+pip3 install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116
+pip3 install -r requirement.txt
+```
+
+### Compile the deformable_aggregation CUDA op
+```bash
+cd projects/mmdet3d_plugin/ops
+python3 setup.py develop
+cd ../../../
+```
+
+### Prepare the data
+Download the [NuScenes dataset](https://www.nuscenes.org/nuscenes#download) and CAN bus expansion, put CAN bus expansion in /path/to/nuscenes, create symbolic links.
+```bash
+cd ${sparsedrive_path}
+mkdir data
+ln -s path/to/nuscenes ./data/nuscenes
+```
+
+Pack the meta-information and labels of the dataset, and generate the required pkl files to data/infos. Note that we also generate map_annos in data_converter, with a roi_size of (30, 60) as default, if you want a different range, you can modify roi_size in tools/data_converter/nuscenes_converter.py.
+```bash
+sh scripts/create_data.sh
+```
+
+### Generate anchors by K-means
+Generated anchors are saved to data/kmeans and can be visualized in vis/kmeans.
+```bash
+sh scripts/kmeans.sh
+```
+
+
+### Download pre-trained weights
+Download the required backbone [pre-trained weights](https://download.pytorch.org/models/resnet50-19c8e357.pth).
+```bash
+mkdir ckpt
+wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O ckpt/resnet50-19c8e357.pth
+```
+
+### Commence training and testing
+```bash
+# train
+sh scripts/train.sh
+
+# test
+sh scripts/test.sh
+```
+
+### Visualization
+```bash
+sh scripts/visualize.sh
+```
diff --git a/projects/configs/sparsedrive_r101_stage1_4gpu.py b/projects/configs/sparsedrive_r101_stage1_4gpu.py
new file mode 100644
index 0000000..b6d5212
--- /dev/null
+++ b/projects/configs/sparsedrive_r101_stage1_4gpu.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 32
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 80
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r101_stage1_4gpu',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (1408, 512)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [8, 16, 32, 64]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=101,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=True,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=False),
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth',
+ prefix='backbone.')),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32]
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.80, 0.94),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=False,
+ with_planning=False,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r101_stage1_8gpu.py b/projects/configs/sparsedrive_r101_stage1_8gpu.py
new file mode 100644
index 0000000..3985938
--- /dev/null
+++ b/projects/configs/sparsedrive_r101_stage1_8gpu.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+# version = 'mini'  # uncomment (and comment the line below) to run on the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 32
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 80
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r101_stage1_8gpu',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (1408, 512)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [8, 16, 32, 64]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=101,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=True,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=False),
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth',
+ prefix='backbone.')),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32]
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.80, 0.94),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=False,
+ with_planning=False,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): 'threshhold' is misspelled — the eval code presumably reads this exact key, so confirm before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r101_stage1_8gpu_noflash.py b/projects/configs/sparsedrive_r101_stage1_8gpu_noflash.py
new file mode 100644
index 0000000..d17cd5f
--- /dev/null
+++ b/projects/configs/sparsedrive_r101_stage1_8gpu_noflash.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+# version = 'mini'  # uncomment (and comment the line below) to run on the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 32
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 80
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r101_stage1_8gpu_noflash',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (1408, 512)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [8, 16, 32, 64]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=101,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=True,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=False),
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth',
+ prefix='backbone.')),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32]
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.80, 0.94),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=False,
+ with_planning=False,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): 'threshhold' is misspelled — the eval code presumably reads this exact key, so confirm before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r101_stage1_8gpu_noflash_dn.py b/projects/configs/sparsedrive_r101_stage1_8gpu_noflash_dn.py
new file mode 100644
index 0000000..f84de5f
--- /dev/null
+++ b/projects/configs/sparsedrive_r101_stage1_8gpu_noflash_dn.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+# version = 'mini'  # quick toggle: uncomment (and remove the line below) for the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 32
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 80
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r101_stage1_8gpu_noflash_dn',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (1408, 512)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [8, 16, 32, 64]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=101,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=True,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=False),
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth',
+ prefix='backbone.')),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32]
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=5,
+ num_temp_dn_groups=3,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.80, 0.94),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=False,
+ with_planning=False,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r101_stage2_4gpu.py b/projects/configs/sparsedrive_r101_stage2_4gpu.py
new file mode 100644
index 0000000..54c06d4
--- /dev/null
+++ b/projects/configs/sparsedrive_r101_stage2_4gpu.py
@@ -0,0 +1,728 @@
+# ================ base config ===================
+# version = 'mini'  # quick toggle: uncomment (and remove the line below) for the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 32
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r101_stage2_4gpu',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (1408, 512)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [8, 16, 32, 64]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=101,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=True,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=False),
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth',
+ prefix='backbone.')),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32]
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.80, 0.94),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_r101_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r101_stage2_4gpu_nomap.py b/projects/configs/sparsedrive_r101_stage2_4gpu_nomap.py
new file mode 100644
index 0000000..73d381f
--- /dev/null
+++ b/projects/configs/sparsedrive_r101_stage2_4gpu_nomap.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+# version = 'mini'  # toggle: uncomment (and comment the next line) to use the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 32
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # NOTE(review): equals num_epochs, so checkpointing/eval run only once, at the very end of training
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,  # NOTE(review): 51 here vs interval=50 on the wandb hook below — probably meant to match
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r101_stage2_4gpu_nomap',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (1408, 512)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [8, 16, 32, 64]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=101,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=True,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=False),
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth',
+ prefix='backbone.')),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32]
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(  # NOTE(review): appears unused — operation_order above has no "cross_gnn" in this no-map config
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.80, 0.94),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,  # NOTE(review): misspelled key ("threshhold") — likely what the evaluator expects; confirm before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_r101_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r101_stage2_8gpu.py b/projects/configs/sparsedrive_r101_stage2_8gpu.py
new file mode 100644
index 0000000..0fb3daa
--- /dev/null
+++ b/projects/configs/sparsedrive_r101_stage2_8gpu.py
@@ -0,0 +1,728 @@
+# ================ base config ===================
+# version = 'mini'  # toggle: uncomment (and comment the next line) to use the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 32
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # NOTE(review): equals num_epochs, so checkpointing/eval run only once, at the very end of training
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,  # NOTE(review): 51 here vs interval=50 on the wandb hook below — probably meant to match
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r101_stage2_8gpu',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (1408, 512)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [8, 16, 32, 64]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=101,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=True,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=False),
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth',
+ prefix='backbone.')),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32]
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.80, 0.94),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,  # NOTE(review): misspelled key ("threshhold") — likely what the evaluator expects; confirm before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_r101_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r101_stage2_8gpu_noflash.py b/projects/configs/sparsedrive_r101_stage2_8gpu_noflash.py
new file mode 100644
index 0000000..1a1a203
--- /dev/null
+++ b/projects/configs/sparsedrive_r101_stage2_8gpu_noflash.py
@@ -0,0 +1,728 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'  # overrides 'mini' above; comment this line out to use the mini split
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 32
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,  # NOTE(review): 51 looks like a typo for 50 (the wandb hook below uses 50) — confirm
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r101_stage2_8gpu_noflash',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (1408, 512)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # requires compiling the custom ops first: run mmdet3d_plugin/ops/setup.py
+strides = [8, 16, 32, 64]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=101,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=True,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=False),
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth',
+ prefix='backbone.')),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32]
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",  # NOTE(review): not switched to "v1.0-mini" when version == 'mini' — confirm intended
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.80, 0.94),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,  # sic: 'threshhold' spelling must match the eval code's keyword
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_r101_stage1.pth'  # overrides the earlier load_from = None
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r101_stage2_8gpu_noflash_dn.py b/projects/configs/sparsedrive_r101_stage2_8gpu_noflash_dn.py
new file mode 100644
index 0000000..34da3ea
--- /dev/null
+++ b/projects/configs/sparsedrive_r101_stage2_8gpu_noflash_dn.py
@@ -0,0 +1,728 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'  # overrides 'mini' above; comment this line out to use the mini split
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 32
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,  # NOTE(review): 51 looks like a typo for 50 (the wandb hook below uses 50) — confirm
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r101_stage2_8gpu_noflash_dn',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (1408, 512)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # requires compiling the custom ops first: run mmdet3d_plugin/ops/setup.py
+strides = [8, 16, 32, 64]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=101,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=True,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=False),
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='ckpt/cascade_mask_rcnn_r101_fpn_1x_nuim_20201024_134804-45215b1e.pth',
+ prefix='backbone.')),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=1, # skip stride-4 so depth branch and gt_depth align at [8, 16, 32]
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=5,
+ num_temp_dn_groups=3,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",  # NOTE(review): not switched to "v1.0-mini" when version == 'mini' — confirm intended
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.80, 0.94),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,  # sic: 'threshhold' spelling must match the eval code's keyword
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_r101_stage1.pth'  # overrides the earlier load_from = None
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage1_4gpu.py b/projects/configs/sparsedrive_r50_stage1_4gpu.py
new file mode 100644
index 0000000..aee18a1
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage1_4gpu.py
@@ -0,0 +1,724 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 64
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 100
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage1_4gpu',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # requires compiling the custom CUDA ops first (run projects/mmdet3d_plugin/ops/setup.py)
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=4e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.5),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=False,
+ with_planning=False,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu.py b/projects/configs/sparsedrive_r50_stage1_8gpu.py
new file mode 100644
index 0000000..fdc93cb
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage1_8gpu.py
@@ -0,0 +1,724 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 64
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 100
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage1_8gpu',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # requires compiling the custom CUDA ops first (run projects/mmdet3d_plugin/ops/setup.py)
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=4e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.5),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=False,
+ with_planning=False,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash.py
new file mode 100644
index 0000000..475b82a
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash.py
@@ -0,0 +1,724 @@
+# ================ base config ===================
+# version = 'mini'  # dead assignment (immediately overridden below); kept as a toggle for the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 64
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 100
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+    interval=51,  # NOTE(review): misaligned with WandbLoggerHook interval=50 below — confirm 51 is intended
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage1_8gpu_noflash',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=4e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.5),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=False,
+ with_planning=False,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE: 'threshhold' misspelling is a consumed dict key — keep in sync with the eval_mode consumer
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_dn.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_dn.py
new file mode 100644
index 0000000..4bf69d4
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_dn.py
@@ -0,0 +1,724 @@
+# ================ base config ===================
+# version = 'mini'  # dead assignment (immediately overridden below); kept as a toggle for the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 64
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 100
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+    interval=51,  # NOTE(review): misaligned with WandbLoggerHook interval=50 below — confirm 51 is intended
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage1_8gpu_noflash_dn',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=5,
+ num_temp_dn_groups=3,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=4e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.5),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=False,
+ with_planning=False,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE: 'threshhold' misspelling is a consumed dict key — keep in sync with the eval_mode consumer
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_dn_rotaug.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_dn_rotaug.py
new file mode 100644
index 0000000..e8859c4
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_dn_rotaug.py
@@ -0,0 +1,724 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 64
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 100
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage1_8gpu_noflash_dn_rotaug',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=5,
+ num_temp_dn_groups=3,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [-0.3925, 0.3925],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=4e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.5),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=False,
+ with_planning=False,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_gtdetmap.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_gtdetmap.py
new file mode 100644
index 0000000..1324b8f
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_gtdetmap.py
@@ -0,0 +1,742 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 64
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 100
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage1_8gpu_noflash_gtdetmap',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=False,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="GTSparseDriveHead",
+ task_config=task_config,
+ num_classes=num_classes,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=False,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=4e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.5),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=False,
+ with_tracking=False,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_gtdetmap_deform.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_gtdetmap_deform.py
new file mode 100644
index 0000000..3ddf06b
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_gtdetmap_deform.py
@@ -0,0 +1,768 @@
+# ================ base config ===================
+version = 'mini'  # immediately overridden below; set the NEXT line to 'mini' to use the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 64
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 100
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage1_8gpu_noflash_gtdetmap',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=False,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="GTSparseDriveHead",
+ task_config=task_config,
+ num_classes=num_classes,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="add",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=False,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=4e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.5),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=False,
+ with_tracking=False,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap.py
new file mode 100644
index 0000000..aac93f2
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap.py
@@ -0,0 +1,722 @@
+# ================ base config ===================
+version = 'mini'  # immediately overridden below; set the NEXT line to 'mini' to use the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 64
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 100
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage1_8gpu_noflash_nomap',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=4e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.5),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=False,
+ with_motion=False,
+ with_planning=False,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_dn.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_dn.py
new file mode 100644
index 0000000..9896ead
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_dn.py
@@ -0,0 +1,722 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'  # overrides the 'mini' default above; comment this line out to use the mini split
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 64
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 100
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage1_8gpu_noflash_nomap_dn',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True  # requires the custom CUDA ops to be built first (run mmdet3d_plugin/ops/setup.py)
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=5,
+ num_temp_dn_groups=3,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=4e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.5),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=False,
+ with_motion=False,
+ with_planning=False,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_dn_rotaug.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_dn_rotaug.py
new file mode 100644
index 0000000..cacdb33
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_dn_rotaug.py
@@ -0,0 +1,722 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'  # overrides the 'mini' default above; comment this line out to use the mini split
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 64
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 100
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage1_8gpu_noflash_nomap_dn_rotaug',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True  # requires the custom CUDA ops to be built first (run mmdet3d_plugin/ops/setup.py)
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=5,
+ num_temp_dn_groups=3,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [-0.3925, 0.3925],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=4e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.5),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=False,
+ with_motion=False,
+ with_planning=False,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_rotaug.py b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_rotaug.py
new file mode 100644
index 0000000..25afc6f
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage1_8gpu_noflash_nomap_rotaug.py
@@ -0,0 +1,722 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval' # overrides the 'mini' assignment above; comment this line out to debug on the mini split
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 64
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 100
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage1_8gpu_noflash_nomap_rotaug',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # temporal queue depth: 3 history frames plus the current frame
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # requires the compiled deformable ops; build them first via projects/mmdet3d_plugin/ops/setup.py
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [-0.3925, 0.3925],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=4e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.5),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=False,
+ with_motion=False,
+ with_planning=False,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_1gpu.py b/projects/configs/sparsedrive_r50_stage2_1gpu.py
new file mode 100644
index 0000000..fdaf380
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_1gpu.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval' # overrides the 'mini' assignment above; comment this line out to debug on the mini split
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 8
+num_gpus = 1
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_1gpu',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # temporal queue depth: 3 history frames plus the current frame
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # requires the compiled deformable ops; build them first via projects/mmdet3d_plugin/ops/setup.py
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=4,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth' # NOTE: overrides the load_from = None set near the top of this config
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu.py b/projects/configs/sparsedrive_r50_stage2_4gpu.py
new file mode 100644
index 0000000..0259073
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+# version = 'mini'  # dead assignment (immediately overridden below); uncomment and drop the next line for the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+    interval=51,  # NOTE(review): odd value — Wandb hook below logs every 50 iters; confirm 51 is intentional
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): 'threshhold' spelling must match the eval consumer's key — confirm before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # overrides the load_from = None set near the top of this config
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24.py
new file mode 100644
index 0000000..7feb7bd
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+# version = 'mini'  # dead assignment (immediately overridden below); uncomment and drop the next line for the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+    interval=51,  # NOTE(review): odd value — Wandb hook below logs every 50 iters; confirm 51 is intentional
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): 'threshhold' spelling must match the eval consumer's key — confirm before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # overrides the load_from = None set near the top of this config
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_gtdetmap.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_gtdetmap.py
new file mode 100644
index 0000000..2c6fe7b
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_gtdetmap.py
@@ -0,0 +1,745 @@
+# ================ base config ===================
+version = 'mini'  # NOTE(review): dead assignment — immediately overwritten below; comment out the next line to train on the mini split
+version = 'trainval'  # active dataset split
+length = {'trainval': 28130, 'mini': 323}  # number of training samples per split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus  # per-GPU batch size (24 // 4 = 6)
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # equals num_epochs, so only the final checkpoint is saved
+
+checkpoint_config = dict(
+    interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_gtdetmap',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=False,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="GTSparseDriveHead",
+ task_config=task_config,
+ num_classes=num_classes,
+ num_map_classes=num_map_classes,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=False,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+# GT det/map variant: only motion and planning are evaluated (det/track/map disabled,
+# consistent with task_config above which trains motion_plan only).
+eval_mode = dict(
+    with_det=False,
+    with_tracking=False,
+    with_map=False,
+    with_motion=True,
+    with_planning=True,
+    tracking_threshold=0.2,  # unused while with_tracking=False, kept for interface consistency
+    motion_threshhold=0.2,  # NOTE(review): "threshhold" (sic) — key name presumably matched by the evaluator; verify before renaming
+)
+evaluation = dict(
+    interval=num_iters_per_epoch*checkpoint_epoch_interval,  # evaluate only when a checkpoint is saved
+    eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # stage-1 weights; overrides the earlier `load_from = None`
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap.py
new file mode 100644
index 0000000..4aed9f5
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap.py
@@ -0,0 +1,724 @@
+# ================ base config ===================
+version = 'mini'  # NOTE(review): dead assignment — immediately overwritten below; comment out the next line to train on the mini split
+version = 'trainval'  # active dataset split
+length = {'trainval': 28130, 'mini': 323}  # number of training samples per split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus  # per-GPU batch size (24 // 4 = 6)
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # equals num_epochs, so only the final checkpoint is saved
+
+checkpoint_config = dict(
+    interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_nomap',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+        operation_order=(  # NOTE(review): unlike the gtdetmap config, "cross_gnn" is omitted here although cross_graph_model is still configured below — presumably intentional for the no-map variant; confirm the unused cross attention is harmless
+            [
+                "temp_gnn",
+                "gnn",
+                "norm",
+                "ffn",
+                "norm",
+            ] * 3 +
+            [
+                "refine",
+            ]
+        ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+# No-map variant: detection/tracking/motion/planning evaluated, map disabled
+# (consistent with task_config above: with_map=False).
+eval_mode = dict(
+    with_det=True,
+    with_tracking=True,
+    with_map=False,
+    with_motion=True,
+    with_planning=True,
+    tracking_threshold=0.2,  # min confidence for a box to enter tracking eval
+    motion_threshhold=0.2,  # NOTE(review): "threshhold" (sic) — key name presumably matched by the evaluator; verify before renaming
+)
+evaluation = dict(
+    interval=num_iters_per_epoch*checkpoint_epoch_interval,  # evaluate only when a checkpoint is saved
+    eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # stage-1 weights; overrides the earlier `load_from = None`
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead.py
new file mode 100644
index 0000000..378c29c
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead.py
@@ -0,0 +1,743 @@
+# ================ base config ===================
+# version = 'mini'  # uncomment to run on the nuScenes mini split for debugging
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"),
+ warmup_ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ warmup_refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_cls_branch=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hard-coded; ignores the 'mini' toggle at the top of this config — confirm intended
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): 'threshhold' misspelling — confirm eval code keys on this exact name before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead_occfhtraineval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead_occfhtraineval.py
new file mode 100644
index 0000000..17466ac
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead_occfhtraineval.py
@@ -0,0 +1,746 @@
+# ================ base config ===================
+# version = 'mini'  # uncomment to run on the nuScenes mini split for debugging
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+                         name='sparsedrive_r50_stage2_4gpu_bs24_nomap_sephead_occfhtraineval',),
+ interval=50)
+ ],
+)
+load_from = None  # overridden to 'ckpt/sparsedrive_stage1.pth' at the bottom of this config (last assignment wins)
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"),
+ warmup_ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ warmup_refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_cls_branch=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hard-coded; ignores the 'mini' toggle at the top of this config — confirm intended
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train_occ.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_occ.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_occ.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): 'threshhold' misspelling — confirm eval code keys on this exact name before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfheval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfheval.py
new file mode 100644
index 0000000..b70c948
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfheval.py
@@ -0,0 +1,728 @@
+# ================ base config ===================
+# Dataset split toggle: set to 'mini' for quick debugging runs, 'trainval' for full training.
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_occfheval',),
+ interval=50)
+ ],
+)
+load_from = None  # overridden by the "pretrained model" section at the bottom of this file
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): stays 'v1.0-trainval' even when the `version` toggle above is 'mini' — confirm this is intended (nuScenes mini uses 'v1.0-mini')
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val_cvocc.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train_cvocc.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_cvocc.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_cvocc.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): misspelling of "threshold" — likely matches the consumer's kwarg name; verify before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfhtraineval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfhtraineval.py
new file mode 100644
index 0000000..d0b45b6
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfhtraineval.py
@@ -0,0 +1,729 @@
+# ================ base config ===================
+# Dataset split toggle: set to 'mini' for quick debugging runs, 'trainval' for full training.
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_occfhtraineval',),
+ interval=50)
+ ],
+)
+load_from = None  # overridden by the "pretrained model" section at the bottom of this file
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): stays 'v1.0-trainval' even when the `version` toggle above is 'mini' — confirm this is intended (nuScenes mini uses 'v1.0-mini')
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val_cvocc.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train_cvocc.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_cvocc.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_cvocc.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): misspelling of "threshold" — likely matches the consumer's kwarg name; verify before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfleval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfleval.py
new file mode 100644
index 0000000..7a7b838
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfleval.py
@@ -0,0 +1,728 @@
+# ================ base config ===================
+# version = 'mini'  # dead assignment: was immediately overridden below; uncomment (and remove the next line) to use the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_occfleval',),
+ interval=50)
+ ],
+)
+load_from = None  # NOTE: overridden at the bottom of this config by load_from = 'ckpt/sparsedrive_stage1.pth'
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hardcoded even when version == 'mini' above -- presumably should be "v1.0-mini" for the mini split; verify
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val_occ.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train_occ.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_occ.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_occ.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): 'threshhold' is misspelled, but the key must match what the evaluator reads -- confirm before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfltraineval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfltraineval.py
new file mode 100644
index 0000000..24c3b45
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occfltraineval.py
@@ -0,0 +1,729 @@
+# ================ base config ===================
+# version = 'mini'  # dead assignment: was immediately overridden below; uncomment (and remove the next line) to use the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_occfltraineval',),
+ interval=50)
+ ],
+)
+load_from = None  # NOTE: overridden at the bottom of this config by load_from = 'ckpt/sparsedrive_stage1.pth'
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hardcoded even when version == 'mini' above -- presumably should be "v1.0-mini" for the mini split; verify
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val_occ.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train_occ.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_occ.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_occ.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): 'threshhold' is misspelled, but the key must match what the evaluator reads -- confirm before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occpeval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occpeval.py
new file mode 100644
index 0000000..0ed163a
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occpeval.py
@@ -0,0 +1,728 @@
+# ================ base config ===================
+version = 'mini'  # dead value — overridden on the next line; comment out the line below to train on the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}  # number of training samples per split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus  # per-GPU batch size
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # equals num_epochs, so only the final checkpoint is saved
+
+checkpoint_config = dict(
+    interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+    interval=51,  # text-log interval; offset from the wandb interval of 50 below
+    hooks=[
+        dict(type="TextLoggerHook", by_epoch=False),
+        dict(type='WandbLoggerHook',
+            init_kwargs=dict(
+                entity='trailab',
+                project='ForeSight',
+                name='sparsedrive_r50_stage2_4gpu_bs24_occpeval',),
+            interval=50)
+    ],
+)
+load_from = None  # re-assigned at the bottom of this file to the stage-1 checkpoint
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)  # (W, H) of the resized camera images
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+            operation_order=(
+                [
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * num_single_frame_decoder
+                + [
+                    "temp_gnn",
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * (num_decoder - num_single_frame_decoder)
+            )[2:],  # [2:] drops the leading "gnn", "norm" of the first single-frame block
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+            operation_order=(
+                [
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * num_single_frame_decoder_map
+                + [
+                    "temp_gnn",
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * (num_decoder - num_single_frame_decoder_map)
+            )[:],  # [:] keeps every op (no-op copy, written parallel to the det head's [2:])
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(  # shared dataset kwargs for train/val/test and eval_config
+    type=dataset_type,
+    data_root=data_root,
+    classes=class_names,
+    map_classes=map_class_names,
+    modality=input_modality,
+    version="v1.0-" + version,  # follow the trainval/mini toggle; was hardcoded "v1.0-trainval", which silently disagreed with anno_root when version='mini'
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(  # dataloader settings for the standard (non-occ) info files
+    samples_per_gpu=batch_size,
+    workers_per_gpu=6,
+    train=dict(  # NOTE(review): sibling configs also pass use_gt_mask=False here — confirm the dataset default matches
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_train.pkl",
+        pipeline=train_pipeline,
+        test_mode=False,
+        data_aug_conf=data_aug_conf,
+        with_seq_flag=True,  # samples are grouped into temporal sequences
+        sequences_split_num=2,
+        keep_consistent_seq_aug=True,  # same augmentation across a sequence
+    ),
+    val=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_val.pkl",
+        pipeline=test_pipeline,
+        data_aug_conf=data_aug_conf,
+        test_mode=True,
+        eval_config=eval_config,  # val uses test_pipeline for input; eval_config supplies GT via eval_pipeline
+    ),
+    test=dict(  # identical to val: evaluation runs on the val split
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_val.pkl",
+        pipeline=test_pipeline,
+        data_aug_conf=data_aug_conf,
+        test_mode=True,
+        eval_config=eval_config,
+    ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(  # which evaluation stages to run
+    with_det=True,
+    with_tracking=True,
+    with_map=True,
+    with_motion=True,
+    with_planning=True,
+    with_occlusion=True,
+    tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE: "threshhold" spelling must match the evaluator's key — do not rename here alone
+)
+evaluation = dict(
+    interval=num_iters_per_epoch*checkpoint_epoch_interval,  # with num_epochs == checkpoint_epoch_interval == 10: eval only at end of training
+    eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # overrides the load_from = None set in the base section; stage 2 starts from stage-1 weights
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occptraineval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occptraineval.py
new file mode 100644
index 0000000..8674701
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_occptraineval.py
@@ -0,0 +1,729 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_occptraineval',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(  # shared dataset kwargs for train/val/test and eval_config
+    type=dataset_type,
+    data_root=data_root,
+    classes=class_names,
+    map_classes=map_class_names,
+    modality=input_modality,
+    version="v1.0-" + version,  # follow the trainval/mini toggle; was hardcoded "v1.0-trainval", which silently disagreed with anno_root when version='mini'
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(  # which evaluation stages to run
+    with_det=True,
+    with_tracking=True,
+    with_map=True,
+    with_motion=True,
+    with_planning=True,
+    with_occlusion=True,
+    tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE: "threshhold" spelling must match the evaluator's key — do not rename here alone
+)
+evaluation = dict(
+    interval=num_iters_per_epoch*checkpoint_epoch_interval,  # with num_epochs == checkpoint_epoch_interval == 10: eval only at end of training
+    eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # overrides the load_from = None set in the base section; stage 2 starts from stage-1 weights
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly.py
new file mode 100644
index 0000000..cd91b47
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly.py
@@ -0,0 +1,732 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_predonly',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=False,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="GTSparseDriveHead",
+ task_config=task_config,
+ num_classes=num_classes,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=False,
+ with_tracking=False,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ca.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ca.py
new file mode 100644
index 0000000..88ceb84
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ca.py
@@ -0,0 +1,549 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_predonly_ca',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+# DDP: backbone/neck receive gradients but the kinematic (constant-acceleration)
+# motion head has no trainable parameters, so mark unused parameters to avoid DDP sync errors.
+find_unused_parameters = True
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=False,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="GTSparseDriveHead",
+ task_config=task_config,
+ num_classes=num_classes,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ motion_plan_head=dict(
+ type='KinematicMotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_acceleration=True,
+ use_turn_rate=False,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=False,
+ with_tracking=False,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ctra.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ctra.py
new file mode 100644
index 0000000..d74cb63
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ctra.py
@@ -0,0 +1,549 @@
+# ================ base config ===================
+# Select the dataset split: keep 'trainval' active; switch to 'mini' for quick smoke tests.
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_predonly_ctra',),
+ interval=50)
+ ],
+)
+load_from = None  # NOTE: overridden at the end of this file by the stage-1 checkpoint
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+# DDP: backbone/neck receive gradients but CV motion head has no trainable
+# parameters, so mark unused parameters to avoid DDP sync errors.
+find_unused_parameters = True
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=False,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="GTSparseDriveHead",
+ task_config=task_config,
+ num_classes=num_classes,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ motion_plan_head=dict(
+ type='KinematicMotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_acceleration=True,
+ use_turn_rate=True,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=False,
+ with_tracking=False,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ctrv.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ctrv.py
new file mode 100644
index 0000000..5a0e4d8
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_ctrv.py
@@ -0,0 +1,549 @@
+# ================ base config ===================
+# Select the dataset split: keep 'trainval' active; switch to 'mini' for quick smoke tests.
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_predonly_ctrv',),
+ interval=50)
+ ],
+)
+load_from = None  # NOTE: overridden at the end of this file by the stage-1 checkpoint
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+# DDP: backbone/neck receive gradients but CV motion head has no trainable
+# parameters, so mark unused parameters to avoid DDP sync errors.
+find_unused_parameters = True
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=False,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="GTSparseDriveHead",
+ task_config=task_config,
+ num_classes=num_classes,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ motion_plan_head=dict(
+ type='KinematicMotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_acceleration=False,
+ use_turn_rate=True,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=False,
+ with_tracking=False,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_cv.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_cv.py
new file mode 100644
index 0000000..cc52505
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_cv.py
@@ -0,0 +1,549 @@
+# ================ base config ===================
+# Select the dataset split: keep 'trainval' active; switch to 'mini' for quick smoke tests.
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_predonly_cv',),
+ interval=50)
+ ],
+)
+load_from = None  # NOTE: overridden at the end of this file by the stage-1 checkpoint
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+# DDP: backbone/neck receive gradients but CV motion head has no trainable
+# parameters, so mark unused parameters to avoid DDP sync errors.
+find_unused_parameters = True
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=False,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="GTSparseDriveHead",
+ task_config=task_config,
+ num_classes=num_classes,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ motion_plan_head=dict(
+ type='KinematicMotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_acceleration=False,
+ use_turn_rate=False,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=False,
+ with_tracking=False,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_deform.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_deform.py
new file mode 100644
index 0000000..82a2691
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_deform.py
@@ -0,0 +1,762 @@
+# ================ base config ===================
+version = 'mini'  # NOTE(review): immediately overridden on the next line; swap/comment to run the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}  # number of training samples per split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus  # per-GPU batch size
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # equals num_epochs, so only the final checkpoint is written
+
+checkpoint_config = dict(
+    interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+    interval=51,  # NOTE(review): 51 here vs 50 for the wandb hook below — confirm the mismatch is intended
+    hooks=[
+        dict(type="TextLoggerHook", by_epoch=False),
+        dict(type='WandbLoggerHook',
+            init_kwargs=dict(
+                entity='trailab',
+                project='ForeSight',
+                name='sparsedrive_r50_stage2_4gpu_bs24_predonly_deform',),
+            interval=50)
+    ],
+)
+load_from = None  # overridden at the bottom of this config with the stage-1 checkpoint
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)  # static loss scale for mixed-precision training
+input_shape = (704, 256)  # note: data_aug_conf below uses input_shape[::-1] as final_dim
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)  # map ROI extent in meters — presumably (x, y); confirm against VectorizeMap
+
+num_sample = 20  # points per vectorized map instance (used as VectorizeMap sample_num below)
+fut_ts = 12  # future timesteps predicted per agent — TODO confirm horizon/frequency
+fut_mode = 6  # agent motion modes (matches data/kmeans/kmeans_motion_6.npy)
+ego_fut_ts = 6  # future timesteps in the ego plan
+ego_fut_mode = 6  # ego planning modes (matches data/kmeans/kmeans_plan_6.npy)
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8  # attention heads / deformable groups
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]  # FPN feature-map strides
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(  # detection and map tasks disabled; only motion/planning is enabled in this stage
+    with_det=False,
+    with_map=False,
+    with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="GTSparseDriveHead",
+ task_config=task_config,
+ num_classes=num_classes,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ # Deformable cross-attention to sensor (image) features inserted
+ # after the GNN self-attention block in each of the 3 decoder
+ # iterations. residual_mode="add" keeps embed_dims unchanged so
+ # the existing AsymmetricFFN (in_channels=embed_dims) is compatible.
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="add",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+    type="AdamW",
+    lr=3e-4,
+    weight_decay=0.001,
+    paramwise_cfg=dict(
+        custom_keys={
+            "img_backbone": dict(lr_mult=0.1),  # backbone trains at 10% of the base lr
+        }
+    ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+    policy="CosineAnnealing",
+    warmup="linear",
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    min_lr_ratio=1e-3,
+)
+runner = dict(
+    type="IterBasedRunner",
+    max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+    with_det=False,
+    with_tracking=False,
+    with_map=False,
+    with_motion=True,  # evaluate only motion and planning, matching task_config
+    with_planning=True,
+    tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): misspelled key ("threshhold") — must match the evaluator's expected spelling, so left as-is
+)
+evaluation = dict(
+    interval=num_iters_per_epoch*checkpoint_epoch_interval,  # evaluate once, at the end of training
+    eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # stage-1 weights; overrides load_from=None set earlier
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand.py
new file mode 100644
index 0000000..a14fe94
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand.py
@@ -0,0 +1,731 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=False,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="GTSparseDriveHead",
+ task_config=task_config,
+ num_classes=num_classes,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"  # info pkls follow the chosen split
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True  # ImageNet statistics
+)
+train_pipeline = [
+    dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+    dict(
+        type="LoadPointsFromFile",  # lidar points only feed the auxiliary depth supervision below
+        coord_type="LIDAR",
+        load_dim=5,
+        use_dim=5,
+        file_client_args=file_client_args,
+    ),
+    dict(type="ResizeCropFlipImage"),
+    dict(
+        type="MultiScaleDepthMapGenerator",
+        downsample=strides[:num_depth_layers],  # depth maps for the first num_depth_layers FPN levels
+    ),
+    dict(type="BBoxRotation"),
+    dict(type="PhotoMetricDistortionMultiViewImage"),
+    dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+    dict(
+        type="CircleObjectRangeFilter",
+        class_dist_thred=[55] * len(class_names),  # drop GT objects beyond 55 m for every class
+    ),
+    dict(type="InstanceNameFilter", classes=class_names),
+    dict(
+        type='VectorizeMap',
+        roi_size=roi_size,
+        simplify=False,
+        normalize=False,
+        sample_num=num_sample,  # resample each map element to num_sample points
+        permute=True,
+    ),
+    dict(type="NuScenesSparse4DAdaptor"),
+    dict(
+        type="Collect",
+        keys=[
+            "img",
+            "timestamp",
+            "projection_mat",
+            "image_wh",
+            "gt_depth",
+            "focal",
+            "gt_bboxes_3d",
+            "gt_labels_3d",
+            'gt_map_labels',
+            'gt_map_pts',
+            'gt_agent_fut_trajs',
+            'gt_agent_fut_masks',
+            'gt_ego_fut_trajs',
+            'gt_ego_fut_masks',
+            'gt_ego_fut_cmd',
+            'ego_status',
+        ],
+        meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+    ),
+]
+test_pipeline = [
+    dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+    dict(type="ResizeCropFlipImage"),
+    dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+    dict(
+        type="CircleObjectRangeFilter",
+        class_dist_thred=[55] * len(class_names),
+    ),
+    dict(type="InstanceNameFilter", classes=class_names),
+    dict(type="NuScenesSparse4DAdaptor"),
+    dict(
+        type="Collect",
+        keys=[
+            "img",
+            "timestamp",
+            "projection_mat",
+            "image_wh",
+            'ego_status',
+            'gt_ego_fut_cmd',
+            "gt_bboxes_3d",
+            "gt_labels_3d",
+        ],
+        meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+    ),
+]
+eval_pipeline = [
+    dict(
+        type="CircleObjectRangeFilter",
+        class_dist_thred=[55] * len(class_names),
+    ),
+    dict(type="InstanceNameFilter", classes=class_names),
+    dict(
+        type='VectorizeMap',
+        roi_size=roi_size,
+        simplify=True,
+        normalize=False,
+    ),
+    dict(
+        type='Collect',
+        keys=[
+            'vectors',
+            "gt_bboxes_3d",
+            "gt_labels_3d",
+            'gt_agent_fut_trajs',
+            'gt_agent_fut_masks',
+            'gt_ego_fut_trajs',
+            'gt_ego_fut_masks',
+            'gt_ego_fut_cmd',
+            'fut_boxes'
+        ],
+        meta_keys=['token', 'timestamp']
+    ),
+]
+
+input_modality = dict(
+    use_lidar=False,  # lidar is loaded in train_pipeline for depth supervision only
+    use_camera=True,
+    use_radar=False,
+    use_map=False,
+    use_external=False,
+)
+
+data_basic_config = dict(
+    type=dataset_type,
+    data_root=data_root,
+    classes=class_names,
+    map_classes=map_class_names,
+    modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hard-coded even when version == 'mini' — confirm intended
+)
+eval_config = dict(
+    **data_basic_config,
+    ann_file=anno_root + 'nuscenes_infos_val.pkl',
+    pipeline=eval_pipeline,
+    test_mode=True,
+)
+data_aug_conf = {
+    "resize_lim": (0.40, 0.47),
+    "final_dim": input_shape[::-1],  # reversed input_shape, i.e. (H, W)
+    "bot_pct_lim": (0.0, 0.0),
+    "rot_lim": (-5.4, 5.4),  # image rotation range in degrees — presumably; verify in ResizeCropFlipImage
+    "H": 900,
+    "W": 1600,
+    "rand_flip": True,
+    "rot3d_range": [0, 0],  # 3D scene rotation disabled
+}
+
+data = dict(
+    samples_per_gpu=batch_size,
+    workers_per_gpu=6,
+    train=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_train.pkl",
+        pipeline=train_pipeline,
+        test_mode=False,
+        data_aug_conf=data_aug_conf,
+        with_seq_flag=True,
+        sequences_split_num=2,  # each scene sequence split in two for temporal training
+        keep_consistent_seq_aug=True,  # same augmentation across a sequence
+    ),
+    val=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_val.pkl",
+        pipeline=test_pipeline,
+        data_aug_conf=data_aug_conf,
+        test_mode=True,
+        eval_config=eval_config,
+    ),
+    test=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_val.pkl",
+        pipeline=test_pipeline,
+        data_aug_conf=data_aug_conf,
+        test_mode=True,
+        eval_config=eval_config,
+    ),
+)
+
+# ================== training ========================
+optimizer = dict(
+    type="AdamW",
+    lr=3e-4,
+    weight_decay=0.001,
+    paramwise_cfg=dict(
+        custom_keys={
+            "img_backbone": dict(lr_mult=0.1),  # 10x smaller LR for the pretrained backbone
+        }
+    ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+    policy="CosineAnnealing",
+    warmup="linear",
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    min_lr_ratio=1e-3,
+)
+runner = dict(
+    type="IterBasedRunner",
+    max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+    with_det=False,
+    with_tracking=False,
+    with_map=False,
+    with_motion=True,
+    with_planning=True,
+    tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): key spelled "threshhold"; consumers expect this spelling — do not rename here
+)
+evaluation = dict(
+    interval=num_iters_per_epoch*checkpoint_epoch_interval,  # evaluate whenever a checkpoint is saved
+    eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand_deform.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand_deform.py
new file mode 100644
index 0000000..91c1a9a
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand_deform.py
@@ -0,0 +1,761 @@
+# ================ base config ===================
+# version = 'mini'  # dead assignment (was immediately overridden); kept as a switch for mini-split smoke tests
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}  # training samples per split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus  # per-GPU batch size
+num_iters_per_epoch = length[version] // (num_gpus * batch_size)  # floor division already yields an int
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+    interval=num_iters_per_epoch * checkpoint_epoch_interval  # with 10 epochs and interval 10: a single save at the end
+)
+log_config = dict(
+    interval=51,  # NOTE(review): 51 vs the wandb hook's 50 — confirm the mismatch is intentional
+    hooks=[
+        dict(type="TextLoggerHook", by_epoch=False),
+        dict(type='WandbLoggerHook',
+            init_kwargs=dict(
+                entity='trailab',
+                project='ForeSight',
+                name='sparsedrive_r50_stage2_4gpu_bs24_predonly_initrand_deform',),
+            interval=50)
+    ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)  # mixed precision with a fixed loss scale
+input_shape = (704, 256)  # (width, height); reversed into data_aug_conf["final_dim"]
+
+# ================== model ========================
+class_names = [
+    "car",
+    "truck",
+    "construction_vehicle",
+    "bus",
+    "trailer",
+    "barrier",
+    "motorcycle",
+    "bicycle",
+    "pedestrian",
+    "traffic_cone",
+]
+map_class_names = [
+    'ped_crossing',
+    'divider',
+    'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)  # local map extent in meters — presumably (lateral, longitudinal); verify against VectorizeMap
+
+num_sample = 20  # points per vectorized map element
+fut_ts = 12  # agent future horizon in timesteps — presumably 6 s at 2 Hz; confirm
+fut_mode = 6  # trajectory modes per agent
+ego_fut_ts = 6  # ego planning horizon in timesteps
+ego_fut_mode = 6
+queue_length = 4  # history + current
+
+embed_dims = 256
+num_groups = 8  # attention heads / deformable groups
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True  # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]  # FPN level strides
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(  # task switches: only the motion/planning branch is enabled in this stage
+    with_det=False,
+    with_map=False,
+    with_motion_plan=True,
+)
+
+model = dict(
+    type="SparseDrive",
+    use_grid_mask=True,
+    use_deformable_func=use_deformable_func,
+    img_backbone=dict(
+        type="ResNet",
+        depth=50,
+        num_stages=4,
+        frozen_stages=-1,  # no frozen stages; backbone fully trainable (at 0.1x LR, see optimizer)
+        norm_eval=False,
+        style="pytorch",
+        with_cp=True,  # gradient checkpointing to save memory
+        out_indices=(0, 1, 2, 3),
+        norm_cfg=dict(type="BN", requires_grad=True),
+        pretrained="ckpt/resnet50-19c8e357.pth",
+    ),
+    img_neck=dict(
+        type="FPN",
+        num_outs=num_levels,
+        start_level=0,
+        out_channels=embed_dims,
+        add_extra_convs="on_output",
+        relu_before_extra_convs=True,
+        in_channels=[256, 512, 1024, 2048],
+    ),
+    depth_branch=dict(  # for auxiliary supervision only
+        type="DenseDepthNet",
+        embed_dims=embed_dims,
+        num_depth_layers=num_depth_layers,
+        loss_weight=0.2,
+    ),
+    head=dict(
+        type="GTSparseDriveHead",
+        task_config=task_config,
+        num_classes=num_classes,
+        det_head=dict(
+            type="Sparse4DHead",
+            cls_threshold_to_reg=0.05,
+            decouple_attn=decouple_attn,
+            instance_bank=dict(
+                type="InstanceBank",
+                num_anchor=900,
+                embed_dims=embed_dims,
+                anchor="data/kmeans/kmeans_det_900.npy",
+                anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+                num_temp_instances=600 if temporal else -1,
+                confidence_decay=0.6,
+                feat_grad=False,
+            ),
+            anchor_encoder=dict(
+                type="SparseBox3DEncoder",
+                vel_dims=3,
+                embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+                mode="cat" if decouple_attn else "add",
+                output_fc=not decouple_attn,
+                in_loops=1,
+                out_loops=4 if decouple_attn else 2,
+            ),
+            num_single_frame_decoder=num_single_frame_decoder,
+            operation_order=(
+                [
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * num_single_frame_decoder
+                + [
+                    "temp_gnn",
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * (num_decoder - num_single_frame_decoder)
+            )[2:],  # [2:] drops the leading "gnn", "norm" of the first (single-frame) layer
+            temp_graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            )
+            if temporal
+            else None,
+            graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            norm_layer=dict(type="LN", normalized_shape=embed_dims),
+            ffn=dict(
+                type="AsymmetricFFN",
+                in_channels=embed_dims * 2,  # 2x because deformable_model below concatenates ("cat") its residual
+                pre_norm=dict(type="LN"),
+                embed_dims=embed_dims,
+                feedforward_channels=embed_dims * 4,
+                num_fcs=2,
+                ffn_drop=drop_out,
+                act_cfg=dict(type="ReLU", inplace=True),
+            ),
+            deformable_model=dict(
+                type="DeformableFeatureAggregation",
+                embed_dims=embed_dims,
+                num_groups=num_groups,
+                num_levels=num_levels,
+                num_cams=6,
+                attn_drop=0.15,
+                use_deformable_func=use_deformable_func,
+                use_camera_embed=True,
+                residual_mode="cat",
+                kps_generator=dict(
+                    type="SparseBox3DKeyPointsGenerator",
+                    num_learnable_pts=6,
+                    fix_scale=[
+                        [0, 0, 0],
+                        [0.45, 0, 0],
+                        [-0.45, 0, 0],
+                        [0, 0.45, 0],
+                        [0, -0.45, 0],
+                        [0, 0, 0.45],
+                        [0, 0, -0.45],
+                    ],
+                ),
+            ),
+            refine_layer=dict(
+                type="SparseBox3DRefinementModule",
+                embed_dims=embed_dims,
+                num_cls=num_classes,
+                refine_yaw=True,
+                with_quality_estimation=with_quality_estimation,
+            ),
+            sampler=dict(
+                type="SparseBox3DTarget",
+                num_dn_groups=0,  # denoising disabled in this stage
+                num_temp_dn_groups=0,
+                dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+                max_dn_gt=32,
+                add_neg_dn=True,
+                cls_weight=2.0,
+                box_weight=0.25,
+                reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+                cls_wise_reg_weights={
+                    class_names.index("traffic_cone"): [
+                        2.0,
+                        2.0,
+                        2.0,
+                        1.0,
+                        1.0,
+                        1.0,
+                        0.0,
+                        0.0,
+                        1.0,
+                        1.0,
+                    ],
+                },
+            ),
+            loss_cls=dict(
+                type="FocalLoss",
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=2.0,
+            ),
+            loss_reg=dict(
+                type="SparseBox3DLoss",
+                loss_box=dict(type="L1Loss", loss_weight=0.25),
+                loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+                loss_yawness=dict(type="GaussianFocalLoss"),
+                cls_allow_reverse=[class_names.index("barrier")],  # 180-degree yaw flips allowed for barriers
+            ),
+            decoder=dict(type="SparseBox3DDecoder"),
+            reg_weights=[2.0] * 3 + [1.0] * 7,
+        ),
+        map_head=dict(
+            type="Sparse4DHead",
+            cls_threshold_to_reg=0.05,
+            decouple_attn=decouple_attn_map,
+            instance_bank=dict(
+                type="InstanceBank",
+                num_anchor=100,
+                embed_dims=embed_dims,
+                anchor="data/kmeans/kmeans_map_100.npy",
+                anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+                num_temp_instances=33 if temporal_map else -1,
+                confidence_decay=0.6,
+                feat_grad=True,
+            ),
+            anchor_encoder=dict(
+                type="SparsePoint3DEncoder",
+                embed_dims=embed_dims,
+                num_sample=num_sample,
+            ),
+            num_single_frame_decoder=num_single_frame_decoder_map,
+            operation_order=(
+                [
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * num_single_frame_decoder_map
+                + [
+                    "temp_gnn",
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * (num_decoder - num_single_frame_decoder_map)
+            )[:],  # [:] keeps the full order (contrast with the det head's [2:] trim)
+            temp_graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            )
+            if temporal_map
+            else None,
+            graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            norm_layer=dict(type="LN", normalized_shape=embed_dims),
+            ffn=dict(
+                type="AsymmetricFFN",
+                in_channels=embed_dims * 2,
+                pre_norm=dict(type="LN"),
+                embed_dims=embed_dims,
+                feedforward_channels=embed_dims * 4,
+                num_fcs=2,
+                ffn_drop=drop_out,
+                act_cfg=dict(type="ReLU", inplace=True),
+            ),
+            deformable_model=dict(
+                type="DeformableFeatureAggregation",
+                embed_dims=embed_dims,
+                num_groups=num_groups,
+                num_levels=num_levels,
+                num_cams=6,
+                attn_drop=0.15,
+                use_deformable_func=use_deformable_func,
+                use_camera_embed=True,
+                residual_mode="cat",
+                kps_generator=dict(
+                    type="SparsePoint3DKeyPointsGenerator",
+                    embed_dims=embed_dims,
+                    num_sample=num_sample,
+                    num_learnable_pts=3,
+                    fix_height=(0, 0.5, -0.5, 1, -1),
+                    ground_height=-1.84023,  # ground height in lidar frame
+                ),
+            ),
+            refine_layer=dict(
+                type="SparsePoint3DRefinementModule",
+                embed_dims=embed_dims,
+                num_sample=num_sample,
+                num_cls=num_map_classes,
+            ),
+            sampler=dict(
+                type="SparsePoint3DTarget",
+                assigner=dict(
+                    type='HungarianLinesAssigner',
+                    cost=dict(
+                        type='MapQueriesCost',
+                        cls_cost=dict(type='FocalLossCost', weight=1.0),
+                        reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+                    ),
+                ),
+                num_cls=num_map_classes,
+                num_sample=num_sample,
+                roi_size=roi_size,
+            ),
+            loss_cls=dict(
+                type="FocalLoss",
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=1.0,
+            ),
+            loss_reg=dict(
+                type="SparseLineLoss",
+                loss_line=dict(
+                    type='LinesL1Loss',
+                    loss_weight=10.0,
+                    beta=0.01,
+                ),
+                num_sample=num_sample,
+                roi_size=roi_size,
+            ),
+            decoder=dict(type="SparsePoint3DDecoder"),
+            reg_weights=[1.0] * 40,  # presumably num_sample (20) points x 2 coords — confirm
+            gt_cls_key="gt_map_labels",
+            gt_reg_key="gt_map_pts",
+            gt_id_key="map_instance_id",
+            with_instance_id=False,
+            task_prefix='map',
+        ),
+        motion_plan_head=dict(
+            type='MotionPlanningHead',
+            fut_ts=fut_ts,
+            fut_mode=fut_mode,
+            ego_fut_ts=ego_fut_ts,
+            ego_fut_mode=ego_fut_mode,
+            motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+            plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+            embed_dims=embed_dims,
+            decouple_attn=decouple_attn_motion,
+            instance_queue=dict(
+                type="InstanceQueue",
+                embed_dims=embed_dims,
+                queue_length=queue_length,
+                tracking_threshold=0.2,
+                feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),  # (8, 22): stride-32 feature map size
+            ),
+            # Deformable cross-attention to sensor (image) features inserted
+            # after the GNN self-attention block in each of the 3 decoder
+            # iterations. residual_mode="add" keeps embed_dims unchanged so
+            # the existing AsymmetricFFN (in_channels=embed_dims) is compatible.
+            operation_order=(
+                [
+                    "temp_gnn",
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "norm",
+                    "ffn",
+                    "norm",
+                ] * 3 +
+                [
+                    "refine",
+                ]
+            ),
+            temp_graph_model=dict(
+                type="MultiheadAttention",  # NOTE(review): plain MultiheadAttention here vs FlashAttention below — confirm intentional
+                embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            cross_graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            deformable_model=dict(
+                type="DeformableFeatureAggregation",
+                embed_dims=embed_dims,
+                num_groups=num_groups,
+                num_levels=num_levels,
+                num_cams=6,
+                attn_drop=0.15,
+                use_deformable_func=use_deformable_func,
+                use_camera_embed=True,
+                residual_mode="add",  # "add" (not "cat") so embed_dims stays 256 for the FFN below
+                kps_generator=dict(
+                    type="SparseBox3DKeyPointsGenerator",
+                    num_learnable_pts=6,
+                    fix_scale=[
+                        [0, 0, 0],
+                        [0.45, 0, 0],
+                        [-0.45, 0, 0],
+                        [0, 0.45, 0],
+                        [0, -0.45, 0],
+                        [0, 0, 0.45],
+                        [0, 0, -0.45],
+                    ],
+                ),
+            ),
+            norm_layer=dict(type="LN", normalized_shape=embed_dims),
+            ffn=dict(
+                type="AsymmetricFFN",
+                in_channels=embed_dims,
+                pre_norm=dict(type="LN"),
+                embed_dims=embed_dims,
+                feedforward_channels=embed_dims * 2,
+                num_fcs=2,
+                ffn_drop=drop_out,
+                act_cfg=dict(type="ReLU", inplace=True),
+            ),
+            refine_layer=dict(
+                type="MotionPlanningRefinementModule",
+                embed_dims=embed_dims,
+                fut_ts=fut_ts,
+                fut_mode=fut_mode,
+                ego_fut_ts=ego_fut_ts,
+                ego_fut_mode=ego_fut_mode,
+            ),
+            motion_sampler=dict(
+                type="MotionTarget",
+            ),
+            motion_loss_cls=dict(
+                type='FocalLoss',
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=0.2
+            ),
+            motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+            planning_sampler=dict(
+                type="PlanningTarget",
+                ego_fut_ts=ego_fut_ts,
+                ego_fut_mode=ego_fut_mode,
+            ),
+            plan_loss_cls=dict(
+                type='FocalLoss',
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=0.5,
+            ),
+            plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+            plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+            motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+            planning_decoder=dict(
+                type="HierarchicalPlanningDecoder",
+                ego_fut_ts=ego_fut_ts,
+                ego_fut_mode=ego_fut_mode,
+                use_rescore=True,
+            ),
+            num_det=50,  # NOTE(review): presumably top-k det/map instances fed to this head — confirm
+            num_map=10,
+        ),
+    ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"  # info pkls follow the chosen split
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True  # ImageNet statistics
+)
+train_pipeline = [
+    dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+    dict(
+        type="LoadPointsFromFile",  # lidar points only feed the auxiliary depth supervision below
+        coord_type="LIDAR",
+        load_dim=5,
+        use_dim=5,
+        file_client_args=file_client_args,
+    ),
+    dict(type="ResizeCropFlipImage"),
+    dict(
+        type="MultiScaleDepthMapGenerator",
+        downsample=strides[:num_depth_layers],  # depth maps for the first num_depth_layers FPN levels
+    ),
+    dict(type="BBoxRotation"),
+    dict(type="PhotoMetricDistortionMultiViewImage"),
+    dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+    dict(
+        type="CircleObjectRangeFilter",
+        class_dist_thred=[55] * len(class_names),  # drop GT objects beyond 55 m for every class
+    ),
+    dict(type="InstanceNameFilter", classes=class_names),
+    dict(
+        type='VectorizeMap',
+        roi_size=roi_size,
+        simplify=False,
+        normalize=False,
+        sample_num=num_sample,  # resample each map element to num_sample points
+        permute=True,
+    ),
+    dict(type="NuScenesSparse4DAdaptor"),
+    dict(
+        type="Collect",
+        keys=[
+            "img",
+            "timestamp",
+            "projection_mat",
+            "image_wh",
+            "gt_depth",
+            "focal",
+            "gt_bboxes_3d",
+            "gt_labels_3d",
+            'gt_map_labels',
+            'gt_map_pts',
+            'gt_agent_fut_trajs',
+            'gt_agent_fut_masks',
+            'gt_ego_fut_trajs',
+            'gt_ego_fut_masks',
+            'gt_ego_fut_cmd',
+            'ego_status',
+        ],
+        meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+    ),
+]
+test_pipeline = [
+    dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+    dict(type="ResizeCropFlipImage"),
+    dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+    dict(
+        type="CircleObjectRangeFilter",
+        class_dist_thred=[55] * len(class_names),
+    ),
+    dict(type="InstanceNameFilter", classes=class_names),
+    dict(type="NuScenesSparse4DAdaptor"),
+    dict(
+        type="Collect",
+        keys=[
+            "img",
+            "timestamp",
+            "projection_mat",
+            "image_wh",
+            'ego_status',
+            'gt_ego_fut_cmd',
+            "gt_bboxes_3d",
+            "gt_labels_3d",
+        ],
+        meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+    ),
+]
+eval_pipeline = [
+    dict(
+        type="CircleObjectRangeFilter",
+        class_dist_thred=[55] * len(class_names),
+    ),
+    dict(type="InstanceNameFilter", classes=class_names),
+    dict(
+        type='VectorizeMap',
+        roi_size=roi_size,
+        simplify=True,
+        normalize=False,
+    ),
+    dict(
+        type='Collect',
+        keys=[
+            'vectors',
+            "gt_bboxes_3d",
+            "gt_labels_3d",
+            'gt_agent_fut_trajs',
+            'gt_agent_fut_masks',
+            'gt_ego_fut_trajs',
+            'gt_ego_fut_masks',
+            'gt_ego_fut_cmd',
+            'fut_boxes'
+        ],
+        meta_keys=['token', 'timestamp']
+    ),
+]
+
+input_modality = dict(
+    use_lidar=False,  # lidar is loaded in train_pipeline for depth supervision only
+    use_camera=True,
+    use_radar=False,
+    use_map=False,
+    use_external=False,
+)
+
+data_basic_config = dict(
+    type=dataset_type,
+    data_root=data_root,
+    classes=class_names,
+    map_classes=map_class_names,
+    modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hard-coded even when version == 'mini' — confirm intended
+)
+eval_config = dict(
+    **data_basic_config,
+    ann_file=anno_root + 'nuscenes_infos_val.pkl',
+    pipeline=eval_pipeline,
+    test_mode=True,
+)
+data_aug_conf = {
+    "resize_lim": (0.40, 0.47),
+    "final_dim": input_shape[::-1],  # reversed input_shape, i.e. (H, W)
+    "bot_pct_lim": (0.0, 0.0),
+    "rot_lim": (-5.4, 5.4),  # image rotation range in degrees — presumably; verify in ResizeCropFlipImage
+    "H": 900,
+    "W": 1600,
+    "rand_flip": True,
+    "rot3d_range": [0, 0],  # 3D scene rotation disabled
+}
+
+data = dict(
+    samples_per_gpu=batch_size,
+    workers_per_gpu=6,
+    train=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_train.pkl",
+        pipeline=train_pipeline,
+        test_mode=False,
+        data_aug_conf=data_aug_conf,
+        with_seq_flag=True,
+        sequences_split_num=2,  # each scene sequence split in two for temporal training
+        keep_consistent_seq_aug=True,  # same augmentation across a sequence
+    ),
+    val=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_val.pkl",
+        pipeline=test_pipeline,
+        data_aug_conf=data_aug_conf,
+        test_mode=True,
+        eval_config=eval_config,
+    ),
+    test=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_val.pkl",
+        pipeline=test_pipeline,
+        data_aug_conf=data_aug_conf,
+        test_mode=True,
+        eval_config=eval_config,
+    ),
+)
+
+# ================== training ========================
+optimizer = dict(
+    type="AdamW",
+    lr=3e-4,
+    weight_decay=0.001,
+    paramwise_cfg=dict(
+        custom_keys={
+            "img_backbone": dict(lr_mult=0.1),  # 10x smaller LR for the pretrained backbone
+        }
+    ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+    policy="CosineAnnealing",
+    warmup="linear",
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    min_lr_ratio=1e-3,
+)
+runner = dict(
+    type="IterBasedRunner",
+    max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+    with_det=False,
+    with_tracking=False,
+    with_map=False,
+    with_motion=True,
+    with_planning=True,
+    tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): key spelled "threshhold"; consumers expect this spelling — do not rename here
+)
+evaluation = dict(
+    interval=num_iters_per_epoch*checkpoint_epoch_interval,  # evaluate whenever a checkpoint is saved
+    eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_refine.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_refine.py
new file mode 100644
index 0000000..52d18e0
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_predonly_refine.py
@@ -0,0 +1,730 @@
+# ================ base config ===================
+# version = 'mini'  # dead assignment (was immediately overridden); kept as a switch for mini-split smoke tests
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}  # training samples per split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus  # per-GPU batch size
+num_iters_per_epoch = length[version] // (num_gpus * batch_size)  # floor division already yields an int
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+    interval=num_iters_per_epoch * checkpoint_epoch_interval  # with 10 epochs and interval 10: a single save at the end
+)
+log_config = dict(
+    interval=51,  # NOTE(review): 51 vs the wandb hook's 50 — confirm the mismatch is intentional
+    hooks=[
+        dict(type="TextLoggerHook", by_epoch=False),
+        dict(type='WandbLoggerHook',
+            init_kwargs=dict(
+                entity='trailab',
+                project='ForeSight',
+                name='sparsedrive_r50_stage2_4gpu_bs24_predonly_refine',),
+            interval=50)
+    ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)  # mixed precision with a fixed loss scale
+input_shape = (704, 256)  # (width, height); reversed into data_aug_conf["final_dim"]
+
+
+# ================== model ========================
+class_names = [
+    "car",
+    "truck",
+    "construction_vehicle",
+    "bus",
+    "trailer",
+    "barrier",
+    "motorcycle",
+    "bicycle",
+    "pedestrian",
+    "traffic_cone",
+]
+map_class_names = [
+    'ped_crossing',
+    'divider',
+    'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)  # local map extent in meters — presumably (lateral, longitudinal); verify against VectorizeMap
+
+num_sample = 20  # points per vectorized map element
+fut_ts = 12  # agent future horizon in timesteps — presumably 6 s at 2 Hz; confirm
+fut_mode = 6  # trajectory modes per agent
+ego_fut_ts = 6  # ego planning horizon in timesteps
+ego_fut_mode = 6
+queue_length = 4  # history + current
+
+embed_dims = 256
+num_groups = 8  # attention heads / deformable groups
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True  # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]  # FPN level strides
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(  # task switches: only the motion/planning branch is enabled in this stage
+    with_det=False,
+    with_map=False,
+    with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="GTSparseDriveHead",
+ task_config=task_config,
+ num_classes=num_classes,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],  # NOTE(review): [:] is a no-op copy — presumably kept for symmetry with the det head's sliced operation_order; confirm
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ "refine",
+ ] * 3  # NOTE(review): no "cross_gnn" step here, yet cross_graph_model is configured below — confirm it is intentionally unused in this variant
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",  # NOTE(review): hardcoded even when version == 'mini' selects mini annotation files — nuScenes mini split uses "v1.0-mini"; verify
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=False,
+ with_tracking=False,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,  # NOTE(review): 'threshhold' misspelling — presumably matches the consumer's key name; verify before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_pt2.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_pt2.py
new file mode 100644
index 0000000..0a2f882
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_pt2.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+# version = 'mini'  # dead assignment in original (immediately overwritten below); uncomment to run the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_pt2',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],  # NOTE(review): [2:] drops the leading "gnn","norm" of the single-frame decoder block — confirm this truncation is intentional
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],  # NOTE(review): [:] is a no-op copy — presumably kept for symmetry with the det head's )[2:] slice; confirm
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",  # NOTE(review): hardcoded even when version == 'mini' selects mini annotation files — nuScenes mini split uses "v1.0-mini"; verify
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,  # NOTE(review): 'threshhold' misspelling — presumably matches the consumer's key name; verify before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1_dn.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_pt2_sephead.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_pt2_sephead.py
new file mode 100644
index 0000000..71f8dec
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_pt2_sephead.py
@@ -0,0 +1,745 @@
+# ================ base config ===================
+# version = 'mini'  # NOTE(review): dead assignment (immediately overridden below); keep commented as a split toggle
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_pt2_sephead',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"),
+ warmup_ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ warmup_refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_cls_branch=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1_dn.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sephead_occfltrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sephead_occfltrainval.py
new file mode 100644
index 0000000..ea91cbb
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sephead_occfltrainval.py
@@ -0,0 +1,748 @@
+# ================ base config ===================
+# version = 'mini'  # NOTE(review): dead assignment (immediately overridden below); keep commented as a split toggle
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+                     name='sparsedrive_r50_stage2_4gpu_bs24_sephead_occfltrainval',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"),
+ warmup_ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ warmup_refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_cls_branch=True,
+ with_quality_estimation=False,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val_occ.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train_occ.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_occ.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_occ.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sephead_occptrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sephead_occptrainval.py
new file mode 100644
index 0000000..96806a9
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sephead_occptrainval.py
@@ -0,0 +1,748 @@
+# ================ base config ===================
+version = 'mini'  # NOTE(review): dead assignment — immediately overridden below; presumably kept as a quick toggle for mini-split runs. Confirm and consider commenting it out.
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}  # number of training samples per split; drives num_iters_per_epoch
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+    interval=51,  # NOTE(review): text-logger interval (51) is offset from the wandb hook interval (50) below — confirm the off-by-one staggering is intentional and not a typo
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_sephead_occptrainval',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"),
+ warmup_ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ warmup_refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_cls_branch=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hard-coded even when version=='mini' switches anno_root to the mini infos — confirm whether 'v1.0-mini' should be selected for mini runs
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): misspelled key ('threshhold') — must match the evaluator's expected kwarg, so renaming here alone would break eval; fix producer and consumer together
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso.py
new file mode 100644
index 0000000..aac9b01
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso.py
@@ -0,0 +1,757 @@
+# ================ base config ===================
+version = 'mini'  # NOTE(review): dead assignment — immediately overridden below; presumably kept as a quick toggle for mini-split runs. Confirm and consider commenting it out.
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}  # number of training samples per split; drives num_iters_per_epoch
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+    interval=51,  # NOTE(review): text-logger interval (51) is offset from the wandb hook interval (50) below — confirm the off-by-one staggering is intentional and not a typo
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"),
+ warmup_supervise_all=True,
+ warmup_ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ warmup_refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_cls_branch=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=False,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=False,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ loss_visibility=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=0.0,
+ alpha=0.2,
+ loss_weight=1.0,
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ 'gt_visibility',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hard-coded even when version=='mini' switches anno_root to the mini infos — confirm whether 'v1.0-mini' should be selected for mini runs
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): misspelled key ('threshhold') — must match the evaluator's expected kwarg, so renaming here alone would break eval; fix producer and consumer together
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso_occptrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso_occptrainval.py
new file mode 100644
index 0000000..f3f6aa6
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso_occptrainval.py
@@ -0,0 +1,759 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_sepheadobsiso_occptrainval',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"),
+ warmup_supervise_all=True,
+ warmup_ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ warmup_refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_cls_branch=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=False,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=False,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ loss_visibility=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=0.0,
+ alpha=0.2,
+ loss_weight=1.0,
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ 'gt_visibility',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvis_occptrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvis_occptrainval.py
new file mode 100644
index 0000000..f15be14
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvis_occptrainval.py
@@ -0,0 +1,758 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_sepheadvis_occptrainval',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"),
+ warmup_ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ warmup_refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_cls_branch=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=True,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=True,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ loss_visibility=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=0.0,
+ alpha=0.2,
+ loss_weight=1.0,
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ 'gt_visibility',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occfltrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occfltrainval.py
new file mode 100644
index 0000000..3a00736
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occfltrainval.py
@@ -0,0 +1,759 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'  # NOTE(review): immediately overrides the 'mini' line above; edit one line to switch splits
+length = {'trainval': 28130, 'mini': 323}  # number of training samples per nuScenes split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus  # per-GPU batch size (24 / 4 = 6)
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # with num_epochs == 10, only the final checkpoint is saved
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occfltrainval',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20  # points sampled per map polyline (see VectorizeMap / SparsePoint3D modules below)
+fut_ts = 12  # future timesteps predicted per agent
+fut_mode = 6  # trajectory modes per agent
+ego_fut_ts = 6  # future timesteps predicted for the ego vehicle
+ego_fut_mode = 6  # trajectory modes for ego planning
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8  # attention heads in the graph/deformable models below
+num_decoder = 6  # total decoder layers (single-frame + temporal)
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]  # FPN feature strides
+num_levels = len(strides)
+num_depth_layers = 3  # levels supervised by the auxiliary depth branch
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"),
+ warmup_supervise_all=True,
+ warmup_ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ warmup_refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_cls_branch=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=True,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=True,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ loss_visibility=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=0.0,
+ alpha=0.4,
+ loss_weight=1.0,
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ 'gt_visibility',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+    **data_basic_config,
+    ann_file=anno_root + 'nuscenes_infos_val_occ.pkl',  # occlusion-annotated val infos (this config's variant)
+    pipeline=eval_pipeline,
+    test_mode=True,
+)
+data_aug_conf = {
+    "resize_lim": (0.40, 0.47),
+    "final_dim": input_shape[::-1],  # (H, W) = (256, 704)
+    "bot_pct_lim": (0.0, 0.0),
+    "rot_lim": (-5.4, 5.4),  # image rotation range; presumably degrees — confirm against the augmentation op
+    "H": 900,
+    "W": 1600,  # raw nuScenes camera resolution
+    "rand_flip": True,
+    "rot3d_range": [0, 0],  # 3D rotation disabled
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train_occ.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_occ.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_occ.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+    type="AdamW",
+    lr=1.5e-4,
+    weight_decay=0.001,
+    paramwise_cfg=dict(
+        custom_keys={
+            "img_backbone": dict(lr_mult=0.1),  # 10x smaller LR for the pretrained ResNet backbone
+        }
+    ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))  # L2 gradient clipping
+lr_config = dict(
+    policy="CosineAnnealing",
+    warmup="linear",
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,  # warmup starts at lr/3
+    min_lr_ratio=1e-3,
+)
+runner = dict(
+    type="IterBasedRunner",
+    max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+    with_det=True,
+    with_tracking=True,
+    with_map=True,
+    with_motion=True,
+    with_planning=True,
+    with_occlusion=True,
+    tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): "threshhold" looks misspelled, but this is a runtime key — only rename together with its consumer
+)
+evaluation = dict(
+    interval=num_iters_per_epoch*checkpoint_epoch_interval,  # evaluate once, at the end of training
+    eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # NOTE(review): overrides load_from = None set near the top of this file
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occptrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occptrainval.py
new file mode 100644
index 0000000..23c85e4
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occptrainval.py
@@ -0,0 +1,759 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_sepheadvisobsiso_occptrainval',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"),
+ warmup_supervise_all=True,
+ warmup_ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ warmup_refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_cls_branch=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=True,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=True,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ loss_visibility=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=0.0,
+ alpha=0.2,
+ loss_weight=1.0,
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ 'gt_visibility',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_validflag.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_validflag.py
new file mode 100644
index 0000000..521d456
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_validflag.py
@@ -0,0 +1,729 @@
+# ================ base config ===================
+version = 'mini'  # NOTE: dead assignment — immediately overridden below; edit the next line to switch splits
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_validflag',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+        class_dist_thred=[55] * len(class_names),  # NOTE: "thred" [sic] — per-class distance threshold (m); key name must match the transform's kwarg
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hardcoded — ignores the `version`/`anno_root` mini-split toggle above; confirm intended for mini runs
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_valid_flag=True
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ use_valid_flag=True
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ use_valid_flag=True
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE: "threshhold" [sic] — key name must match what the evaluator reads; do not rename here alone
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_vishead_occfhtraineval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_vishead_occfhtraineval.py
new file mode 100644
index 0000000..6463489
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_vishead_occfhtraineval.py
@@ -0,0 +1,738 @@
+# ================ base config ===================
+version = 'mini'  # NOTE: dead assignment — immediately overridden below; edit the next line to switch splits
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_vishead_occfhtraineval',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=True,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ loss_visibility=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=0.0,
+ alpha=0.4,
+ loss_weight=1.0,
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ 'gt_visibility',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hardcoded — ignores the `version`/`anno_root` mini-split toggle above; confirm intended for mini runs
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val_cvocc.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train_cvocc.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_cvocc.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_cvocc.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE: "threshhold" [sic] — key name must match what the evaluator reads; do not rename here alone
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_vishead_occptraineval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_vishead_occptraineval.py
new file mode 100644
index 0000000..6afde0f
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_bs24_vishead_occptraineval.py
@@ -0,0 +1,738 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'  # overrides the 'mini' default above; comment out to run the mini split
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 24
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # == num_epochs, so checkpoint/eval trigger only once, at the end
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_bs24_vishead_occptraineval',),
+ interval=50)
+ ],
+)
+load_from = None  # NOTE: re-assigned at the bottom of this config to the stage-1 checkpoint
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=True,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ loss_visibility=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=0.0,
+ alpha=0.2,
+ loss_weight=1.0,
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ 'gt_visibility',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=1.5e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_maplrdiv2.py b/projects/configs/sparsedrive_r50_stage2_4gpu_maplrdiv2.py
new file mode 100644
index 0000000..7a5da18
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_maplrdiv2.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'  # overrides the 'mini' default above; comment out to run the mini split
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # == num_epochs, so checkpoint/eval trigger only once, at the end
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_maplrdiv2',),
+ interval=50)
+ ],
+)
+load_from = None  # NOTE: re-assigned at the bottom of this config to the stage-1 checkpoint
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=5.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_maplrdiv4.py b/projects/configs/sparsedrive_r50_stage2_4gpu_maplrdiv4.py
new file mode 100644
index 0000000..4c654b5
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_maplrdiv4.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+version = 'mini'  # dead assignment: immediately overridden below; swap the two lines to train on the mini split
+version = 'trainval'  # active dataset split selector
+length = {'trainval': 28130, 'mini': 323}  # number of training samples per split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 4
+batch_size = total_batch_size // num_gpus  # per-GPU batch size (48 // 4 = 12)
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # checkpoint once per full run (== num_epochs)
+
+checkpoint_config = dict(
+    interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+    interval=51,
+    hooks=[
+        dict(type="TextLoggerHook", by_epoch=False),
+        dict(type='WandbLoggerHook',
+            init_kwargs=dict(
+                entity='trailab',
+                project='ForeSight',
+                name='sparsedrive_r50_stage2_4gpu_maplrdiv4',),
+            interval=50)
+    ],
+)
+load_from = None  # NOTE: re-assigned to the stage-1 checkpoint at the bottom of this file
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)  # mixed-precision training with a fixed loss scale
+input_shape = (704, 256)  # (W, H); reversed via [::-1] below for data_aug_conf["final_dim"]
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.25,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=2.5,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+    type=dataset_type,
+    data_root=data_root,
+    classes=class_names,
+    map_classes=map_class_names,
+    modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hardcoded even when version == 'mini' above — confirm this should not be "v1.0-mini"
+)
+eval_config = dict(
+    **data_basic_config,
+    ann_file=anno_root + 'nuscenes_infos_val.pkl',
+    pipeline=eval_pipeline,
+    test_mode=True,
+)
+data_aug_conf = {
+    "resize_lim": (0.40, 0.47),
+    "final_dim": input_shape[::-1],  # (H, W) = (256, 704), reversed from input_shape (W, H)
+    "bot_pct_lim": (0.0, 0.0),
+    "rot_lim": (-5.4, 5.4),  # presumably 2D image-rotation range in degrees — TODO confirm units
+    "H": 900,  # raw nuScenes camera image height
+    "W": 1600,  # raw nuScenes camera image width
+    "rand_flip": True,
+    "rot3d_range": [0, 0],  # 3D rotation augmentation disabled (zero range)
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+    type="AdamW",
+    lr=3e-4,
+    weight_decay=0.001,
+    paramwise_cfg=dict(
+        custom_keys={
+            "img_backbone": dict(lr_mult=0.1),  # pretrained backbone trains at 10% of the base LR
+        }
+    ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))  # L2 gradient clipping
+lr_config = dict(
+    policy="CosineAnnealing",
+    warmup="linear",
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    min_lr_ratio=1e-3,
+)
+runner = dict(
+    type="IterBasedRunner",
+    max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+    with_det=True,
+    with_tracking=True,
+    with_map=True,
+    with_motion=True,
+    with_planning=True,
+    tracking_threshold=0.2,
+    motion_threshhold=0.2,  # [sic] misspelled key — presumably matches the consumer's expected spelling; do not rename here alone
+)
+evaluation = dict(
+    interval=num_iters_per_epoch*checkpoint_epoch_interval,  # evaluate once per checkpoint
+    eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # overrides the load_from = None set near the top of this file
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap.py
new file mode 100644
index 0000000..59e5e30
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap.py
@@ -0,0 +1,724 @@
+# ================ base config ===================
+version = 'mini'  # dead assignment: immediately overridden below; swap the two lines to train on the mini split
+version = 'trainval'  # active dataset split selector
+length = {'trainval': 28130, 'mini': 323}  # number of training samples per split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 4
+batch_size = total_batch_size // num_gpus  # per-GPU batch size (48 // 4 = 12)
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # checkpoint once per full run (== num_epochs)
+
+checkpoint_config = dict(
+    interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+    interval=51,
+    hooks=[
+        dict(type="TextLoggerHook", by_epoch=False),
+        dict(type='WandbLoggerHook',
+            init_kwargs=dict(
+                entity='trailab',
+                project='ForeSight',
+                name='sparsedrive_r50_stage2_4gpu_nomap',),
+            interval=50)
+    ],
+)
+load_from = None  # NOTE: re-assigned to the stage-1 checkpoint at the bottom of this file
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)  # mixed-precision training with a fixed loss scale
+input_shape = (704, 256)  # (W, H); reversed via [::-1] below for data_aug_conf["final_dim"]
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+    with_det=True,
+    with_tracking=True,
+    with_map=False,  # map evaluation disabled, consistent with task_config's with_map=False in this variant
+    with_motion=True,
+    with_planning=True,
+    tracking_threshold=0.2,
+    motion_threshhold=0.2,  # [sic] misspelled key — presumably matches the consumer's expected spelling; do not rename here alone
+)
+evaluation = dict(
+    interval=num_iters_per_epoch*checkpoint_epoch_interval,  # evaluate once per checkpoint
+    eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # overrides the load_from = None set near the top of this file
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_noplan.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_noplan.py
new file mode 100644
index 0000000..a567ccc
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_noplan.py
@@ -0,0 +1,724 @@
+# ================ base config ===================
+# version = 'mini'  # uncomment (and comment out the next line) to use the 323-sample debug split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_nomap_noplan',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.0,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=0.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=0.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hardcoded — not switched by the 'mini'/'trainval' toggle above; confirm intended
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=False,
+ with_motion=True,
+ with_planning=False,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # (sic) misspelled key — must match the evaluator's expected spelling; verify before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # overrides load_from = None set earlier in this config
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_noplan_notempmotion.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_noplan_notempmotion.py
new file mode 100644
index 0000000..79cac29
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_noplan_notempmotion.py
@@ -0,0 +1,723 @@
+# ================ base config ===================
+# version = 'mini'  # uncomment (and comment out the next line) to use the 323-sample debug split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_nomap_noplan_notempmotion',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 1 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.0,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=0.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=0.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hardcoded — not switched by the 'mini'/'trainval' toggle above; confirm intended
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=False,
+ with_motion=True,
+ with_planning=False,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # (sic) misspelled key — must match the evaluator's expected spelling; verify before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # overrides load_from = None set earlier in this config
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_notempmotion.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_notempmotion.py
new file mode 100644
index 0000000..b10e732
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_notempmotion.py
@@ -0,0 +1,723 @@
+# ================ base config ===================
+version = 'mini'  # NOTE(review): dead assignment — immediately overwritten below; kept as a manual toggle for the mini split
+version = 'trainval'  # active split; selects the sample count below and the anno_root later
+length = {'trainval': 28130, 'mini': 323}  # training-set sizes per split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"  # custom modules are registered from this plugin directory
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48  # global batch size across all GPUs
+num_gpus = 4
+batch_size = total_batch_size // num_gpus  # per-GPU batch size (12)
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # equals num_epochs, so only the final checkpoint is written
+
+checkpoint_config = dict(
+    interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+    interval=51,  # NOTE(review): text-log interval (51) differs from the wandb hook interval (50) below — confirm intentional
+    hooks=[
+        dict(type="TextLoggerHook", by_epoch=False),
+        dict(type='WandbLoggerHook',
+            init_kwargs=dict(
+                entity='trailab',
+                project='ForeSight',
+                name='sparsedrive_r50_stage2_4gpu_nomap_notempmotion',),  # run name mirrors this config filename
+            interval=50)
+    ],
+)
+load_from = None  # NOTE(review): re-assigned at the bottom of this file to the stage-1 checkpoint
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)  # static loss scale for mixed-precision training
+input_shape = (704, 256)  # (W, H) of resized camera images
+
+
+# ================== model ========================
+class_names = [  # nuScenes detection classes, in label-index order
+    "car",
+    "truck",
+    "construction_vehicle",
+    "bus",
+    "trailer",
+    "barrier",
+    "motorcycle",
+    "bicycle",
+    "pedestrian",
+    "traffic_cone",
+]
+map_class_names = [  # online-map element classes (unused for training here: task_config.with_map=False)
+    'ped_crossing',
+    'divider',
+    'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)  # map BEV region of interest
+
+num_sample = 20  # points sampled per map polyline
+fut_ts = 12  # agent future horizon (timesteps)
+fut_mode = 6  # motion modes per agent
+ego_fut_ts = 6  # ego planning horizon
+ego_fut_mode = 6  # planning modes
+queue_length = 1 # history + current — a single frame, i.e. no motion history (matches 'notempmotion' in the filename)
+
+embed_dims = 256
+num_groups = 8  # attention heads / deformable groups
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]  # FPN output strides
+num_levels = len(strides)
+num_depth_layers = 3  # feature levels supervised by the auxiliary depth branch
+drop_out = 0.1
+temporal = True
+temporal_map = True  # NOTE(review): left True although with_map=False below — presumably irrelevant when the map head is inactive
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(  # which SparseDriveHead sub-heads are active
+    with_det=True,
+    with_map=False,  # map head disabled ('nomap' variant)
+    with_motion_plan=True,
+)
+
+model = dict(
+    type="SparseDrive",
+    use_grid_mask=True,  # grid-mask image augmentation
+    use_deformable_func=use_deformable_func,
+    img_backbone=dict(
+        type="ResNet",
+        depth=50,
+        num_stages=4,
+        frozen_stages=-1,  # no frozen stages; whole backbone trains (at reduced lr, see optimizer)
+        norm_eval=False,
+        style="pytorch",
+        with_cp=True,  # gradient checkpointing to save memory
+        out_indices=(0, 1, 2, 3),
+        norm_cfg=dict(type="BN", requires_grad=True),
+        pretrained="ckpt/resnet50-19c8e357.pth",  # ImageNet-pretrained weights
+    ),
+    img_neck=dict(
+        type="FPN",
+        num_outs=num_levels,
+        start_level=0,
+        out_channels=embed_dims,
+        add_extra_convs="on_output",
+        relu_before_extra_convs=True,
+        in_channels=[256, 512, 1024, 2048],  # ResNet-50 stage channels
+    ),
+    depth_branch=dict( # for auxiliary supervision only
+        type="DenseDepthNet",
+        embed_dims=embed_dims,
+        num_depth_layers=num_depth_layers,
+        loss_weight=0.2,
+    ),
+    head=dict(
+        type="SparseDriveHead",
+        task_config=task_config,
+        det_head=dict(
+            type="Sparse4DHead",
+            cls_threshold_to_reg=0.05,
+            decouple_attn=decouple_attn,
+            instance_bank=dict(
+                type="InstanceBank",
+                num_anchor=900,
+                embed_dims=embed_dims,
+                anchor="data/kmeans/kmeans_det_900.npy",  # k-means box anchors (precomputed)
+                anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+                num_temp_instances=600 if temporal else -1,
+                confidence_decay=0.6,
+                feat_grad=False,
+            ),
+            anchor_encoder=dict(
+                type="SparseBox3DEncoder",
+                vel_dims=3,
+                embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+                mode="cat" if decouple_attn else "add",
+                output_fc=not decouple_attn,
+                in_loops=1,
+                out_loops=4 if decouple_attn else 2,
+            ),
+            num_single_frame_decoder=num_single_frame_decoder,
+            operation_order=(
+                [
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * num_single_frame_decoder
+                + [
+                    "temp_gnn",
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * (num_decoder - num_single_frame_decoder)
+            )[2:],  # [2:] drops the leading "gnn","norm" of the single-frame layer
+            temp_graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn else embed_dims * 2,  # doubled when pos/feat are concatenated
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            )
+            if temporal
+            else None,
+            graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            norm_layer=dict(type="LN", normalized_shape=embed_dims),
+            ffn=dict(
+                type="AsymmetricFFN",
+                in_channels=embed_dims * 2,  # consumes the "cat" residual from the deformable module
+                pre_norm=dict(type="LN"),
+                embed_dims=embed_dims,
+                feedforward_channels=embed_dims * 4,
+                num_fcs=2,
+                ffn_drop=drop_out,
+                act_cfg=dict(type="ReLU", inplace=True),
+            ),
+            deformable_model=dict(
+                type="DeformableFeatureAggregation",
+                embed_dims=embed_dims,
+                num_groups=num_groups,
+                num_levels=num_levels,
+                num_cams=6,  # nuScenes surround-view cameras
+                attn_drop=0.15,
+                use_deformable_func=use_deformable_func,
+                use_camera_embed=True,
+                residual_mode="cat",
+                kps_generator=dict(
+                    type="SparseBox3DKeyPointsGenerator",
+                    num_learnable_pts=6,
+                    fix_scale=[  # fixed keypoints: box center plus +/-0.45 offsets along each axis
+                        [0, 0, 0],
+                        [0.45, 0, 0],
+                        [-0.45, 0, 0],
+                        [0, 0.45, 0],
+                        [0, -0.45, 0],
+                        [0, 0, 0.45],
+                        [0, 0, -0.45],
+                    ],
+                ),
+            ),
+            refine_layer=dict(
+                type="SparseBox3DRefinementModule",
+                embed_dims=embed_dims,
+                num_cls=num_classes,
+                refine_yaw=True,
+                with_quality_estimation=with_quality_estimation,
+            ),
+            sampler=dict(
+                type="SparseBox3DTarget",
+                num_dn_groups=0,  # denoising disabled in this stage-2 config
+                num_temp_dn_groups=0,
+                dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+                max_dn_gt=32,
+                add_neg_dn=True,
+                cls_weight=2.0,
+                box_weight=0.25,
+                reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,  # xyz / size / (yaw+vel zeroed) matching-cost weights
+                cls_wise_reg_weights={
+                    class_names.index("traffic_cone"): [  # per-dimension override for traffic cones
+                        2.0,
+                        2.0,
+                        2.0,
+                        1.0,
+                        1.0,
+                        1.0,
+                        0.0,
+                        0.0,
+                        1.0,
+                        1.0,
+                    ],
+                },
+            ),
+            loss_cls=dict(
+                type="FocalLoss",
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=2.0,
+            ),
+            loss_reg=dict(
+                type="SparseBox3DLoss",
+                loss_box=dict(type="L1Loss", loss_weight=0.25),
+                loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+                loss_yawness=dict(type="GaussianFocalLoss"),
+                cls_allow_reverse=[class_names.index("barrier")],  # barriers may flip yaw by 180 degrees
+            ),
+            decoder=dict(type="SparseBox3DDecoder"),
+            reg_weights=[2.0] * 3 + [1.0] * 7,
+        ),
+        map_head=dict(  # NOTE(review): fully configured but inactive — task_config.with_map=False; presumably ignored by SparseDriveHead (confirm)
+            type="Sparse4DHead",
+            cls_threshold_to_reg=0.05,
+            decouple_attn=decouple_attn_map,
+            instance_bank=dict(
+                type="InstanceBank",
+                num_anchor=100,
+                embed_dims=embed_dims,
+                anchor="data/kmeans/kmeans_map_100.npy",
+                anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+                num_temp_instances=33 if temporal_map else -1,
+                confidence_decay=0.6,
+                feat_grad=True,
+            ),
+            anchor_encoder=dict(
+                type="SparsePoint3DEncoder",
+                embed_dims=embed_dims,
+                num_sample=num_sample,
+            ),
+            num_single_frame_decoder=num_single_frame_decoder_map,
+            operation_order=(
+                [
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * num_single_frame_decoder_map
+                + [
+                    "temp_gnn",
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * (num_decoder - num_single_frame_decoder_map)
+            )[:],  # full order kept (no truncation), unlike the det head's [2:]
+            temp_graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            )
+            if temporal_map
+            else None,
+            graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            norm_layer=dict(type="LN", normalized_shape=embed_dims),
+            ffn=dict(
+                type="AsymmetricFFN",
+                in_channels=embed_dims * 2,
+                pre_norm=dict(type="LN"),
+                embed_dims=embed_dims,
+                feedforward_channels=embed_dims * 4,
+                num_fcs=2,
+                ffn_drop=drop_out,
+                act_cfg=dict(type="ReLU", inplace=True),
+            ),
+            deformable_model=dict(
+                type="DeformableFeatureAggregation",
+                embed_dims=embed_dims,
+                num_groups=num_groups,
+                num_levels=num_levels,
+                num_cams=6,
+                attn_drop=0.15,
+                use_deformable_func=use_deformable_func,
+                use_camera_embed=True,
+                residual_mode="cat",
+                kps_generator=dict(
+                    type="SparsePoint3DKeyPointsGenerator",
+                    embed_dims=embed_dims,
+                    num_sample=num_sample,
+                    num_learnable_pts=3,
+                    fix_height=(0, 0.5, -0.5, 1, -1),
+                    ground_height=-1.84023, # ground height in lidar frame
+                ),
+            ),
+            refine_layer=dict(
+                type="SparsePoint3DRefinementModule",
+                embed_dims=embed_dims,
+                num_sample=num_sample,
+                num_cls=num_map_classes,
+            ),
+            sampler=dict(
+                type="SparsePoint3DTarget",
+                assigner=dict(
+                    type='HungarianLinesAssigner',
+                    cost=dict(
+                        type='MapQueriesCost',
+                        cls_cost=dict(type='FocalLossCost', weight=1.0),
+                        reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+                    ),
+                ),
+                num_cls=num_map_classes,
+                num_sample=num_sample,
+                roi_size=roi_size,
+            ),
+            loss_cls=dict(
+                type="FocalLoss",
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=1.0,
+            ),
+            loss_reg=dict(
+                type="SparseLineLoss",
+                loss_line=dict(
+                    type='LinesL1Loss',
+                    loss_weight=10.0,
+                    beta=0.01,
+                ),
+                num_sample=num_sample,
+                roi_size=roi_size,
+            ),
+            decoder=dict(type="SparsePoint3DDecoder"),
+            reg_weights=[1.0] * 40,  # num_sample (20) points x 2 coords
+            gt_cls_key="gt_map_labels",
+            gt_reg_key="gt_map_pts",
+            gt_id_key="map_instance_id",
+            with_instance_id=False,
+            task_prefix='map',
+        ),
+        motion_plan_head=dict(
+            type='MotionPlanningHead',
+            fut_ts=fut_ts,
+            fut_mode=fut_mode,
+            ego_fut_ts=ego_fut_ts,
+            ego_fut_mode=ego_fut_mode,
+            motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+            plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+            embed_dims=embed_dims,
+            decouple_attn=decouple_attn_motion,
+            instance_queue=dict(
+                type="InstanceQueue",
+                embed_dims=embed_dims,
+                queue_length=queue_length,  # 1 here: current frame only, no history
+                tracking_threshold=0.2,
+                feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),  # (H, W) at the coarsest stride
+            ),
+            operation_order=(  # no "temp_gnn" ops: temporal motion attention disabled ('notempmotion' variant)
+                [
+                    "gnn",
+                    "norm",
+                    "ffn",
+                    "norm",
+                ] * 3 +
+                [
+                    "refine",
+                ]
+            ),
+            temp_graph_model=dict(  # NOTE(review): configured although operation_order never uses "temp_gnn" — presumably unused here (confirm)
+                type="MultiheadAttention",
+                embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            cross_graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            norm_layer=dict(type="LN", normalized_shape=embed_dims),
+            ffn=dict(
+                type="AsymmetricFFN",
+                in_channels=embed_dims,
+                pre_norm=dict(type="LN"),
+                embed_dims=embed_dims,
+                feedforward_channels=embed_dims * 2,
+                num_fcs=2,
+                ffn_drop=drop_out,
+                act_cfg=dict(type="ReLU", inplace=True),
+            ),
+            refine_layer=dict(
+                type="MotionPlanningRefinementModule",
+                embed_dims=embed_dims,
+                fut_ts=fut_ts,
+                fut_mode=fut_mode,
+                ego_fut_ts=ego_fut_ts,
+                ego_fut_mode=ego_fut_mode,
+            ),
+            motion_sampler=dict(
+                type="MotionTarget",
+            ),
+            motion_loss_cls=dict(
+                type='FocalLoss',
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=0.2
+            ),
+            motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+            planning_sampler=dict(
+                type="PlanningTarget",
+                ego_fut_ts=ego_fut_ts,
+                ego_fut_mode=ego_fut_mode,
+            ),
+            plan_loss_cls=dict(
+                type='FocalLoss',
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=0.5,
+            ),
+            plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+            plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+            motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+            planning_decoder=dict(
+                type="HierarchicalPlanningDecoder",
+                ego_fut_ts=ego_fut_ts,
+                ego_fut_mode=ego_fut_mode,
+                use_rescore=True,
+            ),
+            num_det=50,  # top detections fed to motion/planning interaction
+            num_map=10,  # top map elements fed to planning
+        ),
+    ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"  # mini-split infos live in a subfolder
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True  # ImageNet stats; BGR->RGB conversion
+)
+train_pipeline = [
+    dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+    dict(
+        type="LoadPointsFromFile",  # lidar points are only used to build depth supervision below
+        coord_type="LIDAR",
+        load_dim=5,
+        use_dim=5,
+        file_client_args=file_client_args,
+    ),
+    dict(type="ResizeCropFlipImage"),
+    dict(
+        type="MultiScaleDepthMapGenerator",
+        downsample=strides[:num_depth_layers],  # depth targets for the first 3 FPN levels
+    ),
+    dict(type="BBoxRotation"),
+    dict(type="PhotoMetricDistortionMultiViewImage"),
+    dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+    dict(
+        type="CircleObjectRangeFilter",
+        class_dist_thred=[55] * len(class_names),  # drop GT boxes beyond 55 m for every class
+    ),
+    dict(type="InstanceNameFilter", classes=class_names),
+    dict(
+        type='VectorizeMap',
+        roi_size=roi_size,
+        simplify=False,
+        normalize=False,
+        sample_num=num_sample,
+        permute=True,
+    ),
+    dict(type="NuScenesSparse4DAdaptor"),
+    dict(
+        type="Collect",
+        keys=[
+            "img",
+            "timestamp",
+            "projection_mat",
+            "image_wh",
+            "gt_depth",
+            "focal",
+            "gt_bboxes_3d",
+            "gt_labels_3d",
+            'gt_map_labels',
+            'gt_map_pts',
+            'gt_agent_fut_trajs',
+            'gt_agent_fut_masks',
+            'gt_ego_fut_trajs',
+            'gt_ego_fut_masks',
+            'gt_ego_fut_cmd',
+            'ego_status',
+        ],
+        meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+    ),
+]
+test_pipeline = [
+    dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+    dict(type="ResizeCropFlipImage"),
+    dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+    dict(type="NuScenesSparse4DAdaptor"),
+    dict(
+        type="Collect",
+        keys=[
+            "img",
+            "timestamp",
+            "projection_mat",
+            "image_wh",
+            'ego_status',
+            'gt_ego_fut_cmd',
+        ],
+        meta_keys=["T_global", "T_global_inv", "timestamp"],
+    ),
+]
+eval_pipeline = [  # GT-side pipeline used by eval_config below (no image loading)
+    dict(
+        type="CircleObjectRangeFilter",
+        class_dist_thred=[55] * len(class_names),
+    ),
+    dict(type="InstanceNameFilter", classes=class_names),
+    dict(
+        type='VectorizeMap',
+        roi_size=roi_size,
+        simplify=True,
+        normalize=False,
+    ),
+    dict(
+        type='Collect',
+        keys=[
+            'vectors',
+            "gt_bboxes_3d",
+            "gt_labels_3d",
+            'gt_agent_fut_trajs',
+            'gt_agent_fut_masks',
+            'gt_ego_fut_trajs',
+            'gt_ego_fut_masks',
+            'gt_ego_fut_cmd',
+            'fut_boxes'
+        ],
+        meta_keys=['token', 'timestamp']
+    ),
+]
+
+input_modality = dict(
+    use_lidar=False,  # camera-only model (lidar appears in train_pipeline solely for depth supervision)
+    use_camera=True,
+    use_radar=False,
+    use_map=False,
+    use_external=False,
+)
+
+data_basic_config = dict(
+    type=dataset_type,
+    data_root=data_root,
+    classes=class_names,
+    map_classes=map_class_names,
+    modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hardcoded even when version=='mini' above — confirm mini runs expect v1.0-trainval
+)
+eval_config = dict(
+    **data_basic_config,
+    ann_file=anno_root + 'nuscenes_infos_val.pkl',
+    pipeline=eval_pipeline,
+    test_mode=True,
+)
+data_aug_conf = {
+    "resize_lim": (0.40, 0.47),  # random resize factor range
+    "final_dim": input_shape[::-1],  # (H, W) — input_shape reversed
+    "bot_pct_lim": (0.0, 0.0),
+    "rot_lim": (-5.4, 5.4),  # image rotation range (degrees)
+    "H": 900,
+    "W": 1600,  # raw nuScenes image size
+    "rand_flip": True,
+    "rot3d_range": [0, 0],  # 3D rotation augmentation disabled
+}
+
+data = dict(
+    samples_per_gpu=batch_size,
+    workers_per_gpu=6,
+    train=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_train.pkl",
+        pipeline=train_pipeline,
+        test_mode=False,
+        data_aug_conf=data_aug_conf,
+        with_seq_flag=True,
+        sequences_split_num=2,  # each scene split into 2 sub-sequences for the sequential sampler
+        keep_consistent_seq_aug=True,  # same augmentation across frames of one sequence
+    ),
+    val=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_val.pkl",
+        pipeline=test_pipeline,
+        data_aug_conf=data_aug_conf,
+        test_mode=True,
+        eval_config=eval_config,
+    ),
+    test=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_val.pkl",
+        pipeline=test_pipeline,
+        data_aug_conf=data_aug_conf,
+        test_mode=True,
+        eval_config=eval_config,
+    ),
+)
+
+# ================== training ========================
+optimizer = dict(
+    type="AdamW",
+    lr=3e-4,
+    weight_decay=0.001,
+    paramwise_cfg=dict(
+        custom_keys={
+            "img_backbone": dict(lr_mult=0.1),  # pretrained backbone trains at 10% of the base lr
+        }
+    ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))  # L2 gradient clipping
+lr_config = dict(
+    policy="CosineAnnealing",
+    warmup="linear",
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    min_lr_ratio=1e-3,
+)
+runner = dict(
+    type="IterBasedRunner",
+    max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+    with_det=True,
+    with_tracking=True,
+    with_map=False,  # matches task_config.with_map
+    with_motion=True,
+    with_planning=True,
+    tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): 'threshhold' [sic] — key name presumably matched by the evaluator; do not rename here alone
+)
+evaluation = dict(
+    interval=num_iters_per_epoch*checkpoint_epoch_interval,  # evaluates once, at the end of training
+    eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # overrides the earlier load_from=None; stage-1 weights initialize stage 2
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_predrefine2.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_predrefine2.py
new file mode 100644
index 0000000..34d6a20
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_predrefine2.py
@@ -0,0 +1,729 @@
+# ================ base config ===================
+version = 'mini'  # NOTE(review): dead assignment — immediately overwritten; kept as a manual toggle for the mini split
+version = 'trainval'  # active split
+length = {'trainval': 28130, 'mini': 323}  # training-set sizes per split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"  # custom modules registered from this plugin directory
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48  # global batch across all GPUs
+num_gpus = 4
+batch_size = total_batch_size // num_gpus  # per-GPU batch size (12)
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # equals num_epochs: only the final checkpoint is saved
+
+checkpoint_config = dict(
+    interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+    interval=51,  # NOTE(review): differs from the wandb hook interval (50) — confirm intentional
+    hooks=[
+        dict(type="TextLoggerHook", by_epoch=False),
+        dict(type='WandbLoggerHook',
+            init_kwargs=dict(
+                entity='trailab',
+                project='ForeSight',
+                name='sparsedrive_r50_stage2_4gpu_nomap_predrefine2',),  # run name mirrors this config filename
+            interval=50)
+    ],
+)
+load_from = None  # NOTE(review): re-assigned to the stage-1 checkpoint at the bottom of this file
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)  # static loss scale for mixed precision
+input_shape = (704, 256)  # (W, H) of resized camera images
+
+
+# ================== model ========================
+class_names = [  # nuScenes detection classes, in label-index order
+    "car",
+    "truck",
+    "construction_vehicle",
+    "bus",
+    "trailer",
+    "barrier",
+    "motorcycle",
+    "bicycle",
+    "pedestrian",
+    "traffic_cone",
+]
+map_class_names = [  # map classes (unused for training here: task_config.with_map=False)
+    'ped_crossing',
+    'divider',
+    'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)  # map BEV region of interest
+
+num_sample = 20  # points sampled per map polyline
+fut_ts = 12  # agent future horizon (timesteps)
+fut_mode = 6  # motion modes per agent
+ego_fut_ts = 6  # ego planning horizon
+ego_fut_mode = 6  # planning modes
+queue_length = 4 # history + current — 3 history frames feed the motion InstanceQueue (unlike the single-frame 'notempmotion' variant)
+
+embed_dims = 256
+num_groups = 8  # attention heads / deformable groups
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]  # FPN output strides
+num_levels = len(strides)
+num_depth_layers = 3  # feature levels supervised by the auxiliary depth branch
+drop_out = 0.1
+temporal = True
+temporal_map = True  # NOTE(review): left True although with_map=False below — presumably irrelevant when the map head is inactive
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(  # which SparseDriveHead sub-heads are active
+    with_det=True,
+    with_map=False,  # map head disabled ('nomap' variant)
+    with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+        motion_plan_head=dict(
+            type='MotionPlanningHead',
+            fut_ts=fut_ts,
+            fut_mode=fut_mode,
+            ego_fut_ts=ego_fut_ts,
+            ego_fut_mode=ego_fut_mode,
+            motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',  # k-means trajectory anchors (precomputed)
+            plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+            embed_dims=embed_dims,
+            decouple_attn=decouple_attn_motion,
+            instance_queue=dict(
+                type="InstanceQueue",
+                embed_dims=embed_dims,
+                queue_length=queue_length,  # 4 here: temporal history is kept for motion
+                tracking_threshold=0.2,
+                feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),  # (H, W) at the coarsest stride
+            ),
+            operation_order=(  # 'predrefine2' layout: temporal attention ("temp_gnn") in every group, with "refine" after each of the last two groups
+                [
+                    "temp_gnn",
+                    "gnn",
+                    "norm",
+                    "ffn",
+                    "norm",
+                ] +
+                [
+                    "temp_gnn",
+                    "gnn",
+                    "norm",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ] * 2
+            ),
+            temp_graph_model=dict(
+                type="MultiheadAttention",  # plain (non-flash) attention for the temporal branch
+                embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            cross_graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            norm_layer=dict(type="LN", normalized_shape=embed_dims),
+            ffn=dict(
+                type="AsymmetricFFN",
+                in_channels=embed_dims,
+                pre_norm=dict(type="LN"),
+                embed_dims=embed_dims,
+                feedforward_channels=embed_dims * 2,
+                num_fcs=2,
+                ffn_drop=drop_out,
+                act_cfg=dict(type="ReLU", inplace=True),
+            ),
+            refine_layer=dict(
+                type="MotionPlanningRefinementModule",
+                embed_dims=embed_dims,
+                fut_ts=fut_ts,
+                fut_mode=fut_mode,
+                ego_fut_ts=ego_fut_ts,
+                ego_fut_mode=ego_fut_mode,
+            ),
+            motion_sampler=dict(
+                type="MotionTarget",
+            ),
+            motion_loss_cls=dict(
+                type='FocalLoss',
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=0.2
+            ),
+            motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+            planning_sampler=dict(
+                type="PlanningTarget",
+                ego_fut_ts=ego_fut_ts,
+                ego_fut_mode=ego_fut_mode,
+            ),
+            plan_loss_cls=dict(
+                type='FocalLoss',
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=0.5,
+            ),
+            plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+            plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+            motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+            planning_decoder=dict(
+                type="HierarchicalPlanningDecoder",
+                ego_fut_ts=ego_fut_ts,
+                ego_fut_mode=ego_fut_mode,
+                use_rescore=True,
+            ),
+            num_det=50,  # top detections fed to motion/planning interaction
+            num_map=10,  # top map elements fed to planning
+        ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+    with_det=True,
+    with_tracking=True,
+    with_map=False,  # matches task_config.with_map
+    with_motion=True,
+    with_planning=True,
+    tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): 'threshhold' [sic] — key name presumably matched by the evaluator; do not rename here alone
+)
+evaluation = dict(
+    interval=num_iters_per_epoch*checkpoint_epoch_interval,  # evaluates once, at the end of training
+    eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # overrides the earlier load_from=None; stage-1 weights initialize stage 2
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_predrefine3.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_predrefine3.py
new file mode 100644
index 0000000..3dd0a80
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_predrefine3.py
@@ -0,0 +1,722 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_nomap_predrefine3',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ "refine",
+ ] * 3
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occfltrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occfltrainval.py
new file mode 100644
index 0000000..4408418
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occfltrainval.py
@@ -0,0 +1,757 @@
+# ================ base config ===================
+version = 'mini'
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occfltrainval',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"),
+ warmup_supervise_all=True,
+ warmup_ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ warmup_refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_cls_branch=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=True,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=True,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ loss_visibility=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=0.0,
+ alpha=0.4,
+ loss_weight=1.0,
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ 'gt_visibility',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val_occ.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train_occ.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_occ.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val_occ.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occptrainval.py b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occptrainval.py
new file mode 100644
index 0000000..0bb8eef
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occptrainval.py
@@ -0,0 +1,757 @@
+# ================ base config ===================
+version = 'mini'  # NOTE: dead assignment — overridden on the next line; comment that line out to run the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 4
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_4gpu_nomap_sepheadvisobsiso_occptrainval',),
+ interval=50)
+ ],
+)
+load_from = None  # NOTE: overridden at the bottom of this file by load_from = 'ckpt/sparsedrive_stage1.pth'
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ temporal_warmup_order=("gnn", "norm", "ffn", "norm", "refine"),
+ warmup_supervise_all=True,
+ warmup_ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ warmup_refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_cls_branch=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=True,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ with_visibility_estimation=True,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ loss_visibility=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=0.0,
+ alpha=0.2,
+ loss_weight=1.0,
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ 'gt_visibility',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes',
+ 'fut_boxes_occluded',
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hard-coded even when version == 'mini' above — presumably should be "v1.0-mini" for the mini split; verify
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ use_gt_mask=False,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=False,
+ with_motion=True,
+ with_planning=True,
+ with_occlusion=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): 'threshhold' misspelling — key must match the evaluator's expected name; confirm before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_8gpu.py b/projects/configs/sparsedrive_r50_stage2_8gpu.py
new file mode 100644
index 0000000..ba04b85
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_8gpu.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+version = 'mini'  # NOTE: dead assignment — overridden on the next line; comment that line out to run the mini split
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_8gpu',),
+ interval=50)
+ ],
+)
+load_from = None  # NOTE: overridden at the bottom of this file by load_from = 'ckpt/sparsedrive_stage1.pth'
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hard-coded even when version == 'mini' above — presumably should be "v1.0-mini" for the mini split; verify
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): 'threshhold' misspelling — key must match the evaluator's expected name; confirm before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_8gpu_noflash.py b/projects/configs/sparsedrive_r50_stage2_8gpu_noflash.py
new file mode 100644
index 0000000..e44ef22
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_8gpu_noflash.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+version = 'mini'      # dataset split toggle ('mini' = tiny debug subset)
+version = 'trainval'  # last assignment wins — 'trainval' is the active split; comment out to use 'mini'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_8gpu_noflash',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",  # NOTE(review): hardcoded full-set version; does not follow the 'mini' toggle above — confirm intended
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=True,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,  # [sic] misspelled key — must match the evaluator's reader; do not rename here alone
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_8gpu_noplan.py b/projects/configs/sparsedrive_r50_stage2_8gpu_noplan.py
new file mode 100644
index 0000000..956b68c
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_8gpu_noplan.py
@@ -0,0 +1,726 @@
+# ================ base config ===================
+version = 'mini'      # dataset split toggle ('mini' = tiny debug subset)
+version = 'trainval'  # last assignment wins — 'trainval' is the active split; comment out to use 'mini'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_8gpu_noplan',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.0,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=0.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=0.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",  # NOTE(review): hardcoded full-set version; does not follow the 'mini' toggle above — confirm intended
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+ with_det=True,
+ with_tracking=True,
+ with_map=True,
+ with_motion=True,
+ with_planning=False,
+ tracking_threshold=0.2,
+ motion_threshhold=0.2,  # [sic] misspelled key — must match the evaluator's reader; do not rename here alone
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_8gpu_notempmotion.py b/projects/configs/sparsedrive_r50_stage2_8gpu_notempmotion.py
new file mode 100644
index 0000000..eb3cba8
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_8gpu_notempmotion.py
@@ -0,0 +1,725 @@
+# ================ base config ===================
+# Dataset split selector: 'trainval' (full nuScenes) or 'mini' (debug subset).
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}  # training-sample count per split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 8
+batch_size = total_batch_size // num_gpus  # per-GPU batch size (48 // 8 = 6)
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))  # full passes over the selected split
+num_epochs = 10
+checkpoint_epoch_interval = 10  # equals num_epochs, so a checkpoint is saved only once, at the end of training
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_8gpu_notempmotion',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 1 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+    type=dataset_type,
+    data_root=data_root,
+    classes=class_names,
+    map_classes=map_class_names,
+    modality=input_modality,
+    version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",  # follow the split toggle instead of hardcoding trainval
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+    with_det=True,  # presumably enables detection metrics -- verify in the evaluator
+    with_tracking=True,
+    with_map=True,
+    with_motion=True,
+    with_planning=True,
+    tracking_threshold=0.2,  # score cutoff applied before tracking evaluation
+    motion_threshhold=0.2,  # NOTE(review): spelling "threshhold" must match the consumer's key -- confirm before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_8gpu_pretrainv1_noflash_nomap.py b/projects/configs/sparsedrive_r50_stage2_8gpu_pretrainv1_noflash_nomap.py
new file mode 100644
index 0000000..00917cc
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_8gpu_pretrainv1_noflash_nomap.py
@@ -0,0 +1,724 @@
+# ================ base config ===================
+# Dataset split selector: 'trainval' (full nuScenes) or 'mini' (debug subset).
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}  # training-sample count per split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_8gpu_pretrainv1_noflash_nomap',),
+ interval=50)
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+    type=dataset_type,
+    data_root=data_root,
+    classes=class_names,
+    map_classes=map_class_names,
+    modality=input_modality,
+    version="v1.0-trainval" if version == 'trainval' else "v1.0-mini",  # follow the split toggle instead of hardcoding trainval
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+    with_det=True,
+    with_tracking=True,
+    with_map=False,  # map evaluation off, consistent with task_config.with_map=False in this config
+    with_motion=True,
+    with_planning=True,
+    tracking_threshold=0.2,  # score cutoff applied before tracking evaluation
+    motion_threshhold=0.2,  # NOTE(review): spelling "threshhold" must match the consumer's key -- confirm before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1_nomap_dn_rotaug.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_r50_stage2_8gpu_pretrainv2_noflash.py b/projects/configs/sparsedrive_r50_stage2_8gpu_pretrainv2_noflash.py
new file mode 100644
index 0000000..d4a6929
--- /dev/null
+++ b/projects/configs/sparsedrive_r50_stage2_8gpu_pretrainv2_noflash.py
@@ -0,0 +1,724 @@
+# ================ base config ===================
+# version = 'mini'  # uncomment (and comment the line below) for the nuScenes mini split during quick debugging
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type='WandbLoggerHook',
+ init_kwargs=dict(
+ entity='trailab',
+ project='ForeSight',
+ name='sparsedrive_r50_stage2_8gpu_pretrainv2_noflash',),
+ interval=50)
+ ],
+)
+load_from = None  # placeholder; overridden by the stage-1 checkpoint path at the bottom of this file
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)  # mixed-precision training with a fixed loss scale
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=False,
+ with_motion_plan=True,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=33 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+ samples_per_gpu=batch_size,
+ workers_per_gpu=6,
+ train=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_train.pkl",
+ pipeline=train_pipeline,
+ test_mode=False,
+ data_aug_conf=data_aug_conf,
+ with_seq_flag=True,
+ sequences_split_num=2,
+ keep_consistent_seq_aug=True,
+ ),
+ val=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+ test=dict(
+ **data_basic_config,
+ ann_file=anno_root + "nuscenes_infos_val.pkl",
+ pipeline=test_pipeline,
+ data_aug_conf=data_aug_conf,
+ test_mode=True,
+ eval_config=eval_config,
+ ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=3e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.1),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+    with_det=True,
+    with_tracking=True,
+    with_map=False,
+    with_motion=True,
+    with_planning=True,
+    tracking_threshold=0.2,
+    motion_threshhold=0.2,  # (sic) misspelled key -- the evaluator reads this exact name, verify before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1_dn.pth'
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_small_stage1.py b/projects/configs/sparsedrive_small_stage1.py
new file mode 100644
index 0000000..1cbb38b
--- /dev/null
+++ b/projects/configs/sparsedrive_small_stage1.py
@@ -0,0 +1,719 @@
+# ================ base config ===================
+# version = 'mini'  # uncomment (and comment the line below) for the nuScenes mini split during quick debugging
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 64
+num_gpus = 8
+batch_size = total_batch_size // num_gpus
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 100
+checkpoint_epoch_interval = 20
+
+checkpoint_config = dict(
+ interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+ interval=51,
+ hooks=[
+ dict(type="TextLoggerHook", by_epoch=False),
+ dict(type="TensorboardLoggerHook"),
+ ],
+)
+load_from = None
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)
+input_shape = (704, 256)
+
+
+# ================== model ========================
+class_names = [
+ "car",
+ "truck",
+ "construction_vehicle",
+ "bus",
+ "trailer",
+ "barrier",
+ "motorcycle",
+ "bicycle",
+ "pedestrian",
+ "traffic_cone",
+]
+map_class_names = [
+ 'ped_crossing',
+ 'divider',
+ 'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)
+
+num_sample = 20
+fut_ts = 12
+fut_mode = 6
+ego_fut_ts = 6
+ego_fut_mode = 6
+queue_length = 4 # history + current
+
+embed_dims = 256
+num_groups = 8
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]
+num_levels = len(strides)
+num_depth_layers = 3
+drop_out = 0.1
+temporal = True
+temporal_map = True
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True
+
+task_config = dict(
+ with_det=True,
+ with_map=True,
+ with_motion_plan=False,
+)
+
+model = dict(
+ type="SparseDrive",
+ use_grid_mask=True,
+ use_deformable_func=use_deformable_func,
+ img_backbone=dict(
+ type="ResNet",
+ depth=50,
+ num_stages=4,
+ frozen_stages=-1,
+ norm_eval=False,
+ style="pytorch",
+ with_cp=True,
+ out_indices=(0, 1, 2, 3),
+ norm_cfg=dict(type="BN", requires_grad=True),
+ pretrained="ckpt/resnet50-19c8e357.pth",
+ ),
+ img_neck=dict(
+ type="FPN",
+ num_outs=num_levels,
+ start_level=0,
+ out_channels=embed_dims,
+ add_extra_convs="on_output",
+ relu_before_extra_convs=True,
+ in_channels=[256, 512, 1024, 2048],
+ ),
+ depth_branch=dict( # for auxiliary supervision only
+ type="DenseDepthNet",
+ embed_dims=embed_dims,
+ num_depth_layers=num_depth_layers,
+ loss_weight=0.2,
+ ),
+ head=dict(
+ type="SparseDriveHead",
+ task_config=task_config,
+ det_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=900,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_det_900.npy",
+ anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+ num_temp_instances=600 if temporal else -1,
+ confidence_decay=0.6,
+ feat_grad=False,
+ ),
+ anchor_encoder=dict(
+ type="SparseBox3DEncoder",
+ vel_dims=3,
+ embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+ mode="cat" if decouple_attn else "add",
+ output_fc=not decouple_attn,
+ in_loops=1,
+ out_loops=4 if decouple_attn else 2,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder)
+ )[2:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparseBox3DKeyPointsGenerator",
+ num_learnable_pts=6,
+ fix_scale=[
+ [0, 0, 0],
+ [0.45, 0, 0],
+ [-0.45, 0, 0],
+ [0, 0.45, 0],
+ [0, -0.45, 0],
+ [0, 0, 0.45],
+ [0, 0, -0.45],
+ ],
+ ),
+ ),
+ refine_layer=dict(
+ type="SparseBox3DRefinementModule",
+ embed_dims=embed_dims,
+ num_cls=num_classes,
+ refine_yaw=True,
+ with_quality_estimation=with_quality_estimation,
+ ),
+ sampler=dict(
+ type="SparseBox3DTarget",
+ num_dn_groups=0,
+ num_temp_dn_groups=0,
+ dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+ max_dn_gt=32,
+ add_neg_dn=True,
+ cls_weight=2.0,
+ box_weight=0.25,
+ reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+ cls_wise_reg_weights={
+ class_names.index("traffic_cone"): [
+ 2.0,
+ 2.0,
+ 2.0,
+ 1.0,
+ 1.0,
+ 1.0,
+ 0.0,
+ 0.0,
+ 1.0,
+ 1.0,
+ ],
+ },
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0,
+ ),
+ loss_reg=dict(
+ type="SparseBox3DLoss",
+ loss_box=dict(type="L1Loss", loss_weight=0.25),
+ loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+ loss_yawness=dict(type="GaussianFocalLoss"),
+ cls_allow_reverse=[class_names.index("barrier")],
+ ),
+ decoder=dict(type="SparseBox3DDecoder"),
+ reg_weights=[2.0] * 3 + [1.0] * 7,
+ ),
+ map_head=dict(
+ type="Sparse4DHead",
+ cls_threshold_to_reg=0.05,
+ decouple_attn=decouple_attn_map,
+ instance_bank=dict(
+ type="InstanceBank",
+ num_anchor=100,
+ embed_dims=embed_dims,
+ anchor="data/kmeans/kmeans_map_100.npy",
+ anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+ num_temp_instances=0 if temporal_map else -1,
+ confidence_decay=0.6,
+ feat_grad=True,
+ ),
+ anchor_encoder=dict(
+ type="SparsePoint3DEncoder",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ ),
+ num_single_frame_decoder=num_single_frame_decoder_map,
+ operation_order=(
+ [
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * num_single_frame_decoder_map
+ + [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "deformable",
+ "ffn",
+ "norm",
+ "refine",
+ ]
+ * (num_decoder - num_single_frame_decoder_map)
+ )[:],
+ temp_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ )
+ if temporal_map
+ else None,
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims * 2,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 4,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ deformable_model=dict(
+ type="DeformableFeatureAggregation",
+ embed_dims=embed_dims,
+ num_groups=num_groups,
+ num_levels=num_levels,
+ num_cams=6,
+ attn_drop=0.15,
+ use_deformable_func=use_deformable_func,
+ use_camera_embed=True,
+ residual_mode="cat",
+ kps_generator=dict(
+ type="SparsePoint3DKeyPointsGenerator",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_learnable_pts=3,
+ fix_height=(0, 0.5, -0.5, 1, -1),
+ ground_height=-1.84023, # ground height in lidar frame
+ ),
+ ),
+ refine_layer=dict(
+ type="SparsePoint3DRefinementModule",
+ embed_dims=embed_dims,
+ num_sample=num_sample,
+ num_cls=num_map_classes,
+ ),
+ sampler=dict(
+ type="SparsePoint3DTarget",
+ assigner=dict(
+ type='HungarianLinesAssigner',
+ cost=dict(
+ type='MapQueriesCost',
+ cls_cost=dict(type='FocalLossCost', weight=1.0),
+ reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+ ),
+ ),
+ num_cls=num_map_classes,
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ loss_cls=dict(
+ type="FocalLoss",
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0,
+ ),
+ loss_reg=dict(
+ type="SparseLineLoss",
+ loss_line=dict(
+ type='LinesL1Loss',
+ loss_weight=10.0,
+ beta=0.01,
+ ),
+ num_sample=num_sample,
+ roi_size=roi_size,
+ ),
+ decoder=dict(type="SparsePoint3DDecoder"),
+ reg_weights=[1.0] * 40,
+ gt_cls_key="gt_map_labels",
+ gt_reg_key="gt_map_pts",
+ gt_id_key="map_instance_id",
+ with_instance_id=False,
+ task_prefix='map',
+ ),
+ motion_plan_head=dict(
+ type='MotionPlanningHead',
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+ plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+ embed_dims=embed_dims,
+ decouple_attn=decouple_attn_motion,
+ instance_queue=dict(
+ type="InstanceQueue",
+ embed_dims=embed_dims,
+ queue_length=queue_length,
+ tracking_threshold=0.2,
+ feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),
+ ),
+ operation_order=(
+ [
+ "temp_gnn",
+ "gnn",
+ "norm",
+ "cross_gnn",
+ "norm",
+ "ffn",
+ "norm",
+ ] * 3 +
+ [
+ "refine",
+ ]
+ ),
+ temp_graph_model=dict(
+ type="MultiheadAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ cross_graph_model=dict(
+ type="MultiheadFlashAttention",
+ embed_dims=embed_dims,
+ num_heads=num_groups,
+ batch_first=True,
+ dropout=drop_out,
+ ),
+ norm_layer=dict(type="LN", normalized_shape=embed_dims),
+ ffn=dict(
+ type="AsymmetricFFN",
+ in_channels=embed_dims,
+ pre_norm=dict(type="LN"),
+ embed_dims=embed_dims,
+ feedforward_channels=embed_dims * 2,
+ num_fcs=2,
+ ffn_drop=drop_out,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ),
+ refine_layer=dict(
+ type="MotionPlanningRefinementModule",
+ embed_dims=embed_dims,
+ fut_ts=fut_ts,
+ fut_mode=fut_mode,
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ motion_sampler=dict(
+ type="MotionTarget",
+ ),
+ motion_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.2
+ ),
+ motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+ planning_sampler=dict(
+ type="PlanningTarget",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ ),
+ plan_loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=0.5,
+ ),
+ plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+ plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+ motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+ planning_decoder=dict(
+ type="HierarchicalPlanningDecoder",
+ ego_fut_ts=ego_fut_ts,
+ ego_fut_mode=ego_fut_mode,
+ use_rescore=True,
+ ),
+ num_det=50,
+ num_map=10,
+ ),
+ ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(
+ type="LoadPointsFromFile",
+ coord_type="LIDAR",
+ load_dim=5,
+ use_dim=5,
+ file_client_args=file_client_args,
+ ),
+ dict(type="ResizeCropFlipImage"),
+ dict(
+ type="MultiScaleDepthMapGenerator",
+ downsample=strides[:num_depth_layers],
+ ),
+ dict(type="BBoxRotation"),
+ dict(type="PhotoMetricDistortionMultiViewImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=False,
+ normalize=False,
+ sample_num=num_sample,
+ permute=True,
+ ),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ "gt_depth",
+ "focal",
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_map_labels',
+ 'gt_map_pts',
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'ego_status',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+ ),
+]
+test_pipeline = [
+ dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+ dict(type="ResizeCropFlipImage"),
+ dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+ dict(type="NuScenesSparse4DAdaptor"),
+ dict(
+ type="Collect",
+ keys=[
+ "img",
+ "timestamp",
+ "projection_mat",
+ "image_wh",
+ 'ego_status',
+ 'gt_ego_fut_cmd',
+ ],
+ meta_keys=["T_global", "T_global_inv", "timestamp"],
+ ),
+]
+eval_pipeline = [
+ dict(
+ type="CircleObjectRangeFilter",
+ class_dist_thred=[55] * len(class_names),
+ ),
+ dict(type="InstanceNameFilter", classes=class_names),
+ dict(
+ type='VectorizeMap',
+ roi_size=roi_size,
+ simplify=True,
+ normalize=False,
+ ),
+ dict(
+ type='Collect',
+ keys=[
+ 'vectors',
+ "gt_bboxes_3d",
+ "gt_labels_3d",
+ 'gt_agent_fut_trajs',
+ 'gt_agent_fut_masks',
+ 'gt_ego_fut_trajs',
+ 'gt_ego_fut_masks',
+ 'gt_ego_fut_cmd',
+ 'fut_boxes'
+ ],
+ meta_keys=['token', 'timestamp']
+ ),
+]
+
+input_modality = dict(
+ use_lidar=False,
+ use_camera=True,
+ use_radar=False,
+ use_map=False,
+ use_external=False,
+)
+
+data_basic_config = dict(
+ type=dataset_type,
+ data_root=data_root,
+ classes=class_names,
+ map_classes=map_class_names,
+ modality=input_modality,
+ version="v1.0-trainval",
+)
+eval_config = dict(
+ **data_basic_config,
+ ann_file=anno_root + 'nuscenes_infos_val.pkl',
+ pipeline=eval_pipeline,
+ test_mode=True,
+)
+data_aug_conf = {
+ "resize_lim": (0.40, 0.47),
+ "final_dim": input_shape[::-1],
+ "bot_pct_lim": (0.0, 0.0),
+ "rot_lim": (-5.4, 5.4),
+ "H": 900,
+ "W": 1600,
+ "rand_flip": True,
+ "rot3d_range": [0, 0],
+}
+
+data = dict(
+    samples_per_gpu=batch_size,
+    workers_per_gpu=batch_size,  # NOTE(review): sibling configs use a fixed 6; tying workers to batch size may be unintended -- confirm
+    train=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_train.pkl",
+        pipeline=train_pipeline,
+        test_mode=False,
+        data_aug_conf=data_aug_conf,
+        with_seq_flag=True,
+        sequences_split_num=2,
+        keep_consistent_seq_aug=True,
+    ),
+    val=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_val.pkl",
+        pipeline=test_pipeline,
+        data_aug_conf=data_aug_conf,
+        test_mode=True,
+        eval_config=eval_config,
+    ),
+    test=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_val.pkl",
+        pipeline=test_pipeline,
+        data_aug_conf=data_aug_conf,
+        test_mode=True,
+        eval_config=eval_config,
+    ),
+)
+
+# ================== training ========================
+optimizer = dict(
+ type="AdamW",
+ lr=4e-4,
+ weight_decay=0.001,
+ paramwise_cfg=dict(
+ custom_keys={
+ "img_backbone": dict(lr_mult=0.5),
+ }
+ ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+ policy="CosineAnnealing",
+ warmup="linear",
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ min_lr_ratio=1e-3,
+)
+runner = dict(
+ type="IterBasedRunner",
+ max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+    with_det=True,
+    with_tracking=True,
+    with_map=True,
+    with_motion=False,
+    with_planning=False,
+    tracking_threshold=0.2,
+    motion_threshhold=0.2,  # (sic) misspelled key -- the evaluator reads this exact name, verify before renaming
+)
+evaluation = dict(
+ interval=num_iters_per_epoch*checkpoint_epoch_interval,
+ eval_mode=eval_mode,
+)
\ No newline at end of file
diff --git a/projects/configs/sparsedrive_small_stage2.py b/projects/configs/sparsedrive_small_stage2.py
new file mode 100644
index 0000000..94cd085
--- /dev/null
+++ b/projects/configs/sparsedrive_small_stage2.py
@@ -0,0 +1,721 @@
+# ================ base config ===================
+version = 'mini'  # NOTE(review): dead assignment — immediately overwritten below; keep one or comment the other out
+version = 'trainval'
+length = {'trainval': 28130, 'mini': 323}  # training-sample count per split
+
+plugin = True
+plugin_dir = "projects/mmdet3d_plugin/"
+dist_params = dict(backend="nccl")
+log_level = "INFO"
+work_dir = None
+
+total_batch_size = 48
+num_gpus = 8
+batch_size = total_batch_size // num_gpus  # per-GPU batch size
+num_iters_per_epoch = int(length[version] // (num_gpus * batch_size))
+num_epochs = 10
+checkpoint_epoch_interval = 10  # i.e. a single checkpoint at the end of the 10-epoch run
+
+checkpoint_config = dict(
+    interval=num_iters_per_epoch * checkpoint_epoch_interval
+)
+log_config = dict(
+    interval=51,  # log every 51 iterations
+    hooks=[
+        dict(type="TextLoggerHook", by_epoch=False),
+        dict(type="TensorboardLoggerHook"),
+    ],
+)
+load_from = None  # NOTE: overridden at the bottom of this file with the stage-1 checkpoint
+resume_from = None
+workflow = [("train", 1)]
+fp16 = dict(loss_scale=32.0)  # static loss scaling for mixed precision
+input_shape = (704, 256)  # (W, H) of the resized camera input
+
+
+# ================== model ========================
+class_names = [  # nuScenes detection classes (order fixes label indices)
+    "car",
+    "truck",
+    "construction_vehicle",
+    "bus",
+    "trailer",
+    "barrier",
+    "motorcycle",
+    "bicycle",
+    "pedestrian",
+    "traffic_cone",
+]
+map_class_names = [  # online vectorized-map classes
+    'ped_crossing',
+    'divider',
+    'boundary',
+]
+num_classes = len(class_names)
+num_map_classes = len(map_class_names)
+roi_size = (30, 60)  # BEV map ROI extent in meters
+
+num_sample = 20  # points sampled per map polyline
+fut_ts = 12  # agent future horizon (timesteps)
+fut_mode = 6  # motion modes per agent
+ego_fut_ts = 6  # ego planning horizon (timesteps)
+ego_fut_mode = 6  # planning modes
+queue_length = 4  # history + current
+
+embed_dims = 256
+num_groups = 8  # attention heads / deformable groups
+num_decoder = 6
+num_single_frame_decoder = 1
+num_single_frame_decoder_map = 1
+use_deformable_func = True  # mmdet3d_plugin/ops/setup.py needs to be executed
+strides = [4, 8, 16, 32]  # FPN level strides
+num_levels = len(strides)
+num_depth_layers = 3  # FPN levels supervised by the depth branch
+drop_out = 0.1
+temporal = True  # temporal instance bank for detection
+temporal_map = True  # temporal instance bank for mapping
+decouple_attn = True
+decouple_attn_map = False
+decouple_attn_motion = True
+with_quality_estimation = True  # predict centerness/yawness quality scores
+
+task_config = dict(  # enables/disables the three task heads
+    with_det=True,
+    with_map=True,
+    with_motion_plan=True,
+)
+
+model = dict(  # SparseDrive: camera-only detection + online mapping + motion/planning
+    type="SparseDrive",
+    use_grid_mask=True,  # GridMask augmentation on input images
+    use_deformable_func=use_deformable_func,
+    img_backbone=dict(
+        type="ResNet",
+        depth=50,
+        num_stages=4,
+        frozen_stages=-1,  # no frozen stages; whole backbone trains
+        norm_eval=False,
+        style="pytorch",
+        with_cp=True,  # gradient checkpointing to save memory
+        out_indices=(0, 1, 2, 3),
+        norm_cfg=dict(type="BN", requires_grad=True),
+        pretrained="ckpt/resnet50-19c8e357.pth",
+    ),
+    img_neck=dict(
+        type="FPN",
+        num_outs=num_levels,
+        start_level=0,
+        out_channels=embed_dims,
+        add_extra_convs="on_output",
+        relu_before_extra_convs=True,
+        in_channels=[256, 512, 1024, 2048],  # ResNet-50 stage channels
+    ),
+    depth_branch=dict(  # for auxiliary supervision only
+        type="DenseDepthNet",
+        embed_dims=embed_dims,
+        num_depth_layers=num_depth_layers,
+        loss_weight=0.2,
+    ),
+    head=dict(
+        type="SparseDriveHead",
+        task_config=task_config,
+        det_head=dict(  # sparse 3D detection branch
+            type="Sparse4DHead",
+            cls_threshold_to_reg=0.05,
+            decouple_attn=decouple_attn,
+            instance_bank=dict(
+                type="InstanceBank",
+                num_anchor=900,
+                embed_dims=embed_dims,
+                anchor="data/kmeans/kmeans_det_900.npy",  # k-means anchor priors
+                anchor_handler=dict(type="SparseBox3DKeyPointsGenerator"),
+                num_temp_instances=600 if temporal else -1,  # instances carried across frames
+                confidence_decay=0.6,
+                feat_grad=False,
+            ),
+            anchor_encoder=dict(
+                type="SparseBox3DEncoder",
+                vel_dims=3,
+                embed_dims=[128, 32, 32, 64] if decouple_attn else 256,
+                mode="cat" if decouple_attn else "add",
+                output_fc=not decouple_attn,
+                in_loops=1,
+                out_loops=4 if decouple_attn else 2,
+            ),
+            num_single_frame_decoder=num_single_frame_decoder,
+            operation_order=(
+                [
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * num_single_frame_decoder
+                + [
+                    "temp_gnn",
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * (num_decoder - num_single_frame_decoder)
+            )[2:],  # NOTE(review): drops leading "gnn","norm" of the first block — presumably intentional; map head keeps [:]
+            temp_graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            )
+            if temporal
+            else None,
+            graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            norm_layer=dict(type="LN", normalized_shape=embed_dims),
+            ffn=dict(
+                type="AsymmetricFFN",
+                in_channels=embed_dims * 2,
+                pre_norm=dict(type="LN"),
+                embed_dims=embed_dims,
+                feedforward_channels=embed_dims * 4,
+                num_fcs=2,
+                ffn_drop=drop_out,
+                act_cfg=dict(type="ReLU", inplace=True),
+            ),
+            deformable_model=dict(
+                type="DeformableFeatureAggregation",
+                embed_dims=embed_dims,
+                num_groups=num_groups,
+                num_levels=num_levels,
+                num_cams=6,  # nuScenes surround-view cameras
+                attn_drop=0.15,
+                use_deformable_func=use_deformable_func,
+                use_camera_embed=True,
+                residual_mode="cat",
+                kps_generator=dict(
+                    type="SparseBox3DKeyPointsGenerator",
+                    num_learnable_pts=6,
+                    fix_scale=[  # fixed keypoints: box center + face centers
+                        [0, 0, 0],
+                        [0.45, 0, 0],
+                        [-0.45, 0, 0],
+                        [0, 0.45, 0],
+                        [0, -0.45, 0],
+                        [0, 0, 0.45],
+                        [0, 0, -0.45],
+                    ],
+                ),
+            ),
+            refine_layer=dict(
+                type="SparseBox3DRefinementModule",
+                embed_dims=embed_dims,
+                num_cls=num_classes,
+                refine_yaw=True,
+                with_quality_estimation=with_quality_estimation,
+            ),
+            sampler=dict(
+                type="SparseBox3DTarget",
+                num_dn_groups=0,  # denoising disabled in stage 2
+                num_temp_dn_groups=0,
+                dn_noise_scale=[2.0] * 3 + [0.5] * 7,
+                max_dn_gt=32,
+                add_neg_dn=True,
+                cls_weight=2.0,
+                box_weight=0.25,
+                reg_weights=[2.0] * 3 + [0.5] * 3 + [0.0] * 4,
+                cls_wise_reg_weights={  # per-class override (traffic cones: no yaw term)
+                    class_names.index("traffic_cone"): [
+                        2.0,
+                        2.0,
+                        2.0,
+                        1.0,
+                        1.0,
+                        1.0,
+                        0.0,
+                        0.0,
+                        1.0,
+                        1.0,
+                    ],
+                },
+            ),
+            loss_cls=dict(
+                type="FocalLoss",
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=2.0,
+            ),
+            loss_reg=dict(
+                type="SparseBox3DLoss",
+                loss_box=dict(type="L1Loss", loss_weight=0.25),
+                loss_centerness=dict(type="CrossEntropyLoss", use_sigmoid=True),
+                loss_yawness=dict(type="GaussianFocalLoss"),
+                cls_allow_reverse=[class_names.index("barrier")],  # barriers are yaw-symmetric
+            ),
+            decoder=dict(type="SparseBox3DDecoder"),
+            reg_weights=[2.0] * 3 + [1.0] * 7,
+        ),
+        map_head=dict(  # online vectorized mapping branch
+            type="Sparse4DHead",
+            cls_threshold_to_reg=0.05,
+            decouple_attn=decouple_attn_map,
+            instance_bank=dict(
+                type="InstanceBank",
+                num_anchor=100,
+                embed_dims=embed_dims,
+                anchor="data/kmeans/kmeans_map_100.npy",
+                anchor_handler=dict(type="SparsePoint3DKeyPointsGenerator"),
+                num_temp_instances=33 if temporal_map else -1,
+                confidence_decay=0.6,
+                feat_grad=True,
+            ),
+            anchor_encoder=dict(
+                type="SparsePoint3DEncoder",
+                embed_dims=embed_dims,
+                num_sample=num_sample,
+            ),
+            num_single_frame_decoder=num_single_frame_decoder_map,
+            operation_order=(
+                [
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * num_single_frame_decoder_map
+                + [
+                    "temp_gnn",
+                    "gnn",
+                    "norm",
+                    "deformable",
+                    "ffn",
+                    "norm",
+                    "refine",
+                ]
+                * (num_decoder - num_single_frame_decoder_map)
+            )[:],  # full sequence kept (det head slices [2:])
+            temp_graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            )
+            if temporal_map
+            else None,
+            graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn_map else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            norm_layer=dict(type="LN", normalized_shape=embed_dims),
+            ffn=dict(
+                type="AsymmetricFFN",
+                in_channels=embed_dims * 2,
+                pre_norm=dict(type="LN"),
+                embed_dims=embed_dims,
+                feedforward_channels=embed_dims * 4,
+                num_fcs=2,
+                ffn_drop=drop_out,
+                act_cfg=dict(type="ReLU", inplace=True),
+            ),
+            deformable_model=dict(
+                type="DeformableFeatureAggregation",
+                embed_dims=embed_dims,
+                num_groups=num_groups,
+                num_levels=num_levels,
+                num_cams=6,
+                attn_drop=0.15,
+                use_deformable_func=use_deformable_func,
+                use_camera_embed=True,
+                residual_mode="cat",
+                kps_generator=dict(
+                    type="SparsePoint3DKeyPointsGenerator",
+                    embed_dims=embed_dims,
+                    num_sample=num_sample,
+                    num_learnable_pts=3,
+                    fix_height=(0, 0.5, -0.5, 1, -1),  # sampled heights around ground plane
+                    ground_height=-1.84023,  # ground height in lidar frame
+                ),
+            ),
+            refine_layer=dict(
+                type="SparsePoint3DRefinementModule",
+                embed_dims=embed_dims,
+                num_sample=num_sample,
+                num_cls=num_map_classes,
+            ),
+            sampler=dict(
+                type="SparsePoint3DTarget",
+                assigner=dict(
+                    type='HungarianLinesAssigner',
+                    cost=dict(
+                        type='MapQueriesCost',
+                        cls_cost=dict(type='FocalLossCost', weight=1.0),
+                        reg_cost=dict(type='LinesL1Cost', weight=10.0, beta=0.01, permute=True),
+                    ),
+                ),
+                num_cls=num_map_classes,
+                num_sample=num_sample,
+                roi_size=roi_size,
+            ),
+            loss_cls=dict(
+                type="FocalLoss",
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=1.0,
+            ),
+            loss_reg=dict(
+                type="SparseLineLoss",
+                loss_line=dict(
+                    type='LinesL1Loss',
+                    loss_weight=10.0,
+                    beta=0.01,
+                ),
+                num_sample=num_sample,
+                roi_size=roi_size,
+            ),
+            decoder=dict(type="SparsePoint3DDecoder"),
+            reg_weights=[1.0] * 40,  # 20 points x (x, y)
+            gt_cls_key="gt_map_labels",
+            gt_reg_key="gt_map_pts",
+            gt_id_key="map_instance_id",
+            with_instance_id=False,
+            task_prefix='map',
+        ),
+        motion_plan_head=dict(  # joint agent-motion forecasting + ego planning
+            type='MotionPlanningHead',
+            fut_ts=fut_ts,
+            fut_mode=fut_mode,
+            ego_fut_ts=ego_fut_ts,
+            ego_fut_mode=ego_fut_mode,
+            motion_anchor=f'data/kmeans/kmeans_motion_{fut_mode}.npy',
+            plan_anchor=f'data/kmeans/kmeans_plan_{ego_fut_mode}.npy',
+            embed_dims=embed_dims,
+            decouple_attn=decouple_attn_motion,
+            instance_queue=dict(
+                type="InstanceQueue",
+                embed_dims=embed_dims,
+                queue_length=queue_length,
+                tracking_threshold=0.2,
+                feature_map_scale=(input_shape[1]/strides[-1], input_shape[0]/strides[-1]),  # coarsest-level (H, W)
+            ),
+            operation_order=(
+                [
+                    "temp_gnn",
+                    "gnn",
+                    "norm",
+                    "cross_gnn",
+                    "norm",
+                    "ffn",
+                    "norm",
+                ] * 3 +
+                [
+                    "refine",
+                ]
+            ),
+            temp_graph_model=dict(
+                type="MultiheadAttention",
+                embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims if not decouple_attn_motion else embed_dims * 2,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            cross_graph_model=dict(
+                type="MultiheadFlashAttention",
+                embed_dims=embed_dims,
+                num_heads=num_groups,
+                batch_first=True,
+                dropout=drop_out,
+            ),
+            norm_layer=dict(type="LN", normalized_shape=embed_dims),
+            ffn=dict(
+                type="AsymmetricFFN",
+                in_channels=embed_dims,
+                pre_norm=dict(type="LN"),
+                embed_dims=embed_dims,
+                feedforward_channels=embed_dims * 2,
+                num_fcs=2,
+                ffn_drop=drop_out,
+                act_cfg=dict(type="ReLU", inplace=True),
+            ),
+            refine_layer=dict(
+                type="MotionPlanningRefinementModule",
+                embed_dims=embed_dims,
+                fut_ts=fut_ts,
+                fut_mode=fut_mode,
+                ego_fut_ts=ego_fut_ts,
+                ego_fut_mode=ego_fut_mode,
+            ),
+            motion_sampler=dict(
+                type="MotionTarget",
+            ),
+            motion_loss_cls=dict(
+                type='FocalLoss',
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=0.2
+            ),
+            motion_loss_reg=dict(type='L1Loss', loss_weight=0.2),
+            planning_sampler=dict(
+                type="PlanningTarget",
+                ego_fut_ts=ego_fut_ts,
+                ego_fut_mode=ego_fut_mode,
+            ),
+            plan_loss_cls=dict(
+                type='FocalLoss',
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=0.5,
+            ),
+            plan_loss_reg=dict(type='L1Loss', loss_weight=1.0),
+            plan_loss_status=dict(type='L1Loss', loss_weight=1.0),
+            motion_decoder=dict(type="SparseBox3DMotionDecoder"),
+            planning_decoder=dict(
+                type="HierarchicalPlanningDecoder",
+                ego_fut_ts=ego_fut_ts,
+                ego_fut_mode=ego_fut_mode,
+                use_rescore=True,
+            ),
+            num_det=50,  # top detections fed to motion/planning
+            num_map=10,  # top map elements fed to motion/planning
+        ),
+    ),
+)
+
+# ================== data ========================
+dataset_type = "NuScenes3DDataset"
+data_root = "data/nuscenes/"
+anno_root = "data/infos/" if version == 'trainval' else "data/infos/mini/"
+file_client_args = dict(backend="disk")
+
+img_norm_cfg = dict(  # ImageNet mean/std normalization
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True
+)
+train_pipeline = [
+    dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+    dict(
+        type="LoadPointsFromFile",  # lidar points used only for depth supervision
+        coord_type="LIDAR",
+        load_dim=5,
+        use_dim=5,
+        file_client_args=file_client_args,
+    ),
+    dict(type="ResizeCropFlipImage"),
+    dict(
+        type="MultiScaleDepthMapGenerator",
+        downsample=strides[:num_depth_layers],
+    ),
+    dict(type="BBoxRotation"),
+    dict(type="PhotoMetricDistortionMultiViewImage"),
+    dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+    dict(
+        type="CircleObjectRangeFilter",
+        class_dist_thred=[55] * len(class_names),  # keep GT within 55 m radius
+    ),
+    dict(type="InstanceNameFilter", classes=class_names),
+    dict(
+        type='VectorizeMap',
+        roi_size=roi_size,
+        simplify=False,
+        normalize=False,
+        sample_num=num_sample,
+        permute=True,
+    ),
+    dict(type="NuScenesSparse4DAdaptor"),
+    dict(
+        type="Collect",
+        keys=[
+            "img",
+            "timestamp",
+            "projection_mat",
+            "image_wh",
+            "gt_depth",
+            "focal",
+            "gt_bboxes_3d",
+            "gt_labels_3d",
+            'gt_map_labels',
+            'gt_map_pts',
+            'gt_agent_fut_trajs',
+            'gt_agent_fut_masks',
+            'gt_ego_fut_trajs',
+            'gt_ego_fut_masks',
+            'gt_ego_fut_cmd',
+            'ego_status',
+        ],
+        meta_keys=["T_global", "T_global_inv", "timestamp", "instance_id"],
+    ),
+]
+test_pipeline = [  # inference: no GT loading, no photometric augmentation
+    dict(type="LoadMultiViewImageFromFiles", to_float32=True),
+    dict(type="ResizeCropFlipImage"),
+    dict(type="NormalizeMultiviewImage", **img_norm_cfg),
+    dict(type="NuScenesSparse4DAdaptor"),
+    dict(
+        type="Collect",
+        keys=[
+            "img",
+            "timestamp",
+            "projection_mat",
+            "image_wh",
+            'ego_status',
+            'gt_ego_fut_cmd',
+        ],
+        meta_keys=["T_global", "T_global_inv", "timestamp"],
+    ),
+]
+eval_pipeline = [  # GT preparation for offline evaluation
+    dict(
+        type="CircleObjectRangeFilter",
+        class_dist_thred=[55] * len(class_names),
+    ),
+    dict(type="InstanceNameFilter", classes=class_names),
+    dict(
+        type='VectorizeMap',
+        roi_size=roi_size,
+        simplify=True,
+        normalize=False,
+    ),
+    dict(
+        type='Collect',
+        keys=[
+            'vectors',
+            "gt_bboxes_3d",
+            "gt_labels_3d",
+            'gt_agent_fut_trajs',
+            'gt_agent_fut_masks',
+            'gt_ego_fut_trajs',
+            'gt_ego_fut_masks',
+            'gt_ego_fut_cmd',
+            'fut_boxes'
+        ],
+        meta_keys=['token', 'timestamp']
+    ),
+]
+
+input_modality = dict(  # camera-only model
+    use_lidar=False,
+    use_camera=True,
+    use_radar=False,
+    use_map=False,
+    use_external=False,
+)
+
+data_basic_config = dict(
+    type=dataset_type,
+    data_root=data_root,
+    classes=class_names,
+    map_classes=map_class_names,
+    modality=input_modality,
+    version="v1.0-trainval",  # NOTE(review): hard-coded even when version=='mini' — confirm 'v1.0-mini' is not needed there
+)
+eval_config = dict(
+    **data_basic_config,
+    ann_file=anno_root + 'nuscenes_infos_val.pkl',
+    pipeline=eval_pipeline,
+    test_mode=True,
+)
+data_aug_conf = {  # image-space augmentation, shared by all splits
+    "resize_lim": (0.40, 0.47),
+    "final_dim": input_shape[::-1],  # (H, W) crop fed to the network
+    "bot_pct_lim": (0.0, 0.0),
+    "rot_lim": (-5.4, 5.4),  # image rotation range (degrees)
+    "H": 900,
+    "W": 1600,
+    "rand_flip": True,
+    "rot3d_range": [0, 0],  # 3D rotation disabled
+}
+
+data = dict(
+    samples_per_gpu=batch_size,
+    workers_per_gpu=batch_size,  # NOTE(review): workers tied to batch size — confirm intended
+    train=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_train.pkl",
+        pipeline=train_pipeline,
+        test_mode=False,
+        data_aug_conf=data_aug_conf,
+        with_seq_flag=True,
+        sequences_split_num=2,
+        keep_consistent_seq_aug=True,
+    ),
+    val=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_val.pkl",
+        pipeline=test_pipeline,
+        data_aug_conf=data_aug_conf,
+        test_mode=True,
+        eval_config=eval_config,
+    ),
+    test=dict(
+        **data_basic_config,
+        ann_file=anno_root + "nuscenes_infos_val.pkl",
+        pipeline=test_pipeline,
+        data_aug_conf=data_aug_conf,
+        test_mode=True,
+        eval_config=eval_config,
+    ),
+)
+
+# ================== training ========================
+optimizer = dict(
+    type="AdamW",
+    lr=3e-4,  # lower base LR than stage 1 (fine-tuning)
+    weight_decay=0.001,
+    paramwise_cfg=dict(
+        custom_keys={
+            "img_backbone": dict(lr_mult=0.1),  # backbone nearly frozen in stage 2
+        }
+    ),
+)
+optimizer_config = dict(grad_clip=dict(max_norm=25, norm_type=2))
+lr_config = dict(
+    policy="CosineAnnealing",
+    warmup="linear",
+    warmup_iters=500,
+    warmup_ratio=1.0 / 3,
+    min_lr_ratio=1e-3,
+)
+runner = dict(
+    type="IterBasedRunner",
+    max_iters=num_iters_per_epoch * num_epochs,
+)
+
+# ================== eval ========================
+eval_mode = dict(
+    with_det=True,
+    with_tracking=True,
+    with_map=True,
+    with_motion=True,  # motion/planning eval enabled in stage 2
+    with_planning=True,
+    tracking_threshold=0.2,
+    motion_threshhold=0.2,  # NOTE(review): "threshhold" typo — key name consumed downstream, rename only in tandem
+)
+evaluation = dict(
+    interval=num_iters_per_epoch*checkpoint_epoch_interval,  # evaluate once at end of run
+    eval_mode=eval_mode,
+)
+# ================== pretrained model ========================
+load_from = 'ckpt/sparsedrive_stage1.pth'  # warm-start from stage-1 weights (overrides load_from above)
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/__init__.py b/projects/mmdet3d_plugin/__init__.py
new file mode 100644
index 0000000..e6f4ea2
--- /dev/null
+++ b/projects/mmdet3d_plugin/__init__.py
@@ -0,0 +1,4 @@
+from .datasets import *
+from .models import *
+from .apis import *
+from .core.evaluation import *
diff --git a/projects/mmdet3d_plugin/apis/__init__.py b/projects/mmdet3d_plugin/apis/__init__.py
new file mode 100644
index 0000000..baab0c6
--- /dev/null
+++ b/projects/mmdet3d_plugin/apis/__init__.py
@@ -0,0 +1,4 @@
+from .train import custom_train_model
+from .mmdet_train import custom_train_detector
+
+# from .test import custom_multi_gpu_test
diff --git a/projects/mmdet3d_plugin/apis/mmdet_train.py b/projects/mmdet3d_plugin/apis/mmdet_train.py
new file mode 100644
index 0000000..ad6dc60
--- /dev/null
+++ b/projects/mmdet3d_plugin/apis/mmdet_train.py
@@ -0,0 +1,219 @@
+# ---------------------------------------------
+# Copyright (c) OpenMMLab. All rights reserved.
+# ---------------------------------------------
+# Modified by Zhiqi Li
+# ---------------------------------------------
+import random
+import warnings
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import (
+ HOOKS,
+ DistSamplerSeedHook,
+ EpochBasedRunner,
+ Fp16OptimizerHook,
+ OptimizerHook,
+ build_optimizer,
+ build_runner,
+ get_dist_info,
+)
+from mmcv.utils import build_from_cfg
+
+from mmdet.core import EvalHook
+
+from mmdet.datasets import build_dataset, replace_ImageToTensor
+from mmdet.utils import get_root_logger
+import time
+import os.path as osp
+from projects.mmdet3d_plugin.datasets.builder import build_dataloader
+from projects.mmdet3d_plugin.core.evaluation.eval_hooks import (
+ CustomDistEvalHook,
+)
+from projects.mmdet3d_plugin.datasets import custom_build_dataset
+
+
+def custom_train_detector(
+    model,
+    dataset,  # a Dataset or list of Datasets (one dataloader per entry)
+    cfg,  # full mmcv Config: data/optimizer/runner/hooks/etc.
+    distributed=False,
+    validate=False,  # when True, registers an eval hook on cfg.data.val
+    timestamp=None,  # shared timestamp so .log and .log.json names match
+    meta=None,
+):
+    logger = get_root_logger(cfg.log_level)
+
+    # prepare data loaders
+
+    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+    # assert len(dataset)==1s
+    if "imgs_per_gpu" in cfg.data:  # legacy key handling (pre-MMDet-2.0 configs)
+        logger.warning(
+            '"imgs_per_gpu" is deprecated in MMDet V2.0. '
+            'Please use "samples_per_gpu" instead'
+        )
+        if "samples_per_gpu" in cfg.data:
+            logger.warning(  # NOTE(review): message claims imgs_per_gpu wins, but this branch never assigns it — samples_per_gpu is actually used; confirm intent
+                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
+                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
+                f"={cfg.data.imgs_per_gpu} is used in this experiments"
+            )
+        else:
+            logger.warning(
+                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                f"{cfg.data.imgs_per_gpu} in this experiments"
+            )
+            cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
+
+    if "runner" in cfg:
+        runner_type = cfg.runner["type"]
+    else:
+        runner_type = "EpochBasedRunner"  # mmcv default when unspecified
+    data_loaders = [
+        build_dataloader(
+            ds,
+            cfg.data.samples_per_gpu,
+            cfg.data.workers_per_gpu,
+            # cfg.gpus will be ignored if distributed
+            len(cfg.gpu_ids),
+            dist=distributed,
+            seed=cfg.seed,
+            nonshuffler_sampler=dict(
+                type="DistributedSampler"
+            ),  # dict(type='DistributedSampler'),
+            runner_type=runner_type,
+        )
+        for ds in dataset
+    ]
+
+    # put model on gpus
+    if distributed:
+        find_unused_parameters = cfg.get("find_unused_parameters", False)
+        # Sets the `find_unused_parameters` parameter in
+        # torch.nn.parallel.DistributedDataParallel
+        model = MMDistributedDataParallel(
+            model.cuda(),
+            device_ids=[torch.cuda.current_device()],
+            broadcast_buffers=False,
+            find_unused_parameters=find_unused_parameters,
+        )
+
+    else:
+        model = MMDataParallel(
+            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids
+        )
+
+    # build runner
+    optimizer = build_optimizer(model, cfg.optimizer)
+
+    if "runner" not in cfg:  # legacy configs: synthesize a runner section
+        cfg.runner = {
+            "type": "EpochBasedRunner",
+            "max_epochs": cfg.total_epochs,
+        }
+        warnings.warn(
+            "config is now expected to have a `runner` section, "
+            "please set `runner` in your config.",
+            UserWarning,
+        )
+    else:
+        if "total_epochs" in cfg:  # guard against contradictory epoch settings
+            assert cfg.total_epochs == cfg.runner.max_epochs
+
+    runner = build_runner(
+        cfg.runner,
+        default_args=dict(
+            model=model,
+            optimizer=optimizer,
+            work_dir=cfg.work_dir,
+            logger=logger,
+            meta=meta,
+        ),
+    )
+
+    # an ugly workaround to make .log and .log.json filenames the same
+    runner.timestamp = timestamp
+
+    # fp16 setting
+    fp16_cfg = cfg.get("fp16", None)
+    if fp16_cfg is not None:  # mixed precision: wrap optimizer step with loss scaling
+        optimizer_config = Fp16OptimizerHook(
+            **cfg.optimizer_config, **fp16_cfg, distributed=distributed
+        )
+    elif distributed and "type" not in cfg.optimizer_config:
+        optimizer_config = OptimizerHook(**cfg.optimizer_config)
+    else:
+        optimizer_config = cfg.optimizer_config  # already a full hook config
+
+    # register hooks
+    runner.register_training_hooks(
+        cfg.lr_config,
+        optimizer_config,
+        cfg.checkpoint_config,
+        cfg.log_config,
+        cfg.get("momentum_config", None),
+    )
+
+    # register profiler hook
+    # trace_config = dict(type='tb_trace', dir_name='work_dir')
+    # profiler_config = dict(on_trace_ready=trace_config)
+    # runner.register_profiler_hook(profiler_config)
+
+    if distributed:
+        if isinstance(runner, EpochBasedRunner):  # reseed sampler each epoch
+            runner.register_hook(DistSamplerSeedHook())
+
+    # register eval hooks
+    if validate:
+        # Support batch_size > 1 in validation
+        val_samples_per_gpu = cfg.data.val.pop("samples_per_gpu", 1)
+        if val_samples_per_gpu > 1:  # multi-sample validation unsupported here
+            assert False
+            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
+            cfg.data.val.pipeline = replace_ImageToTensor(
+                cfg.data.val.pipeline
+            )
+        val_dataset = custom_build_dataset(cfg.data.val, dict(test_mode=True))
+
+        val_dataloader = build_dataloader(
+            val_dataset,
+            samples_per_gpu=val_samples_per_gpu,
+            workers_per_gpu=cfg.data.workers_per_gpu,
+            dist=distributed,
+            shuffle=False,  # deterministic order for evaluation
+            nonshuffler_sampler=dict(type="DistributedSampler"),
+        )
+        eval_cfg = cfg.get("evaluation", {})
+        eval_cfg["by_epoch"] = cfg.runner["type"] != "IterBasedRunner"
+        eval_cfg["jsonfile_prefix"] = osp.join(  # NOTE(review): if work_dir is absolute, os.path.join discards "val" — confirm argument order
+            "val",
+            cfg.work_dir,
+            time.ctime().replace(" ", "_").replace(":", "_"),
+        )
+        eval_hook = CustomDistEvalHook if distributed else EvalHook
+        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
+
+    # user-defined hooks
+    if cfg.get("custom_hooks", None):
+        custom_hooks = cfg.custom_hooks
+        assert isinstance(
+            custom_hooks, list
+        ), f"custom_hooks expect list type, but got {type(custom_hooks)}"
+        for hook_cfg in cfg.custom_hooks:
+            assert isinstance(hook_cfg, dict), (
+                "Each item in custom_hooks expects dict type, but got "
+                f"{type(hook_cfg)}"
+            )
+            hook_cfg = hook_cfg.copy()  # don't mutate the config in place
+            priority = hook_cfg.pop("priority", "NORMAL")
+            hook = build_from_cfg(hook_cfg, HOOKS)
+            runner.register_hook(hook, priority=priority)
+
+    if cfg.resume_from:  # resume takes precedence over plain weight loading
+        runner.resume(cfg.resume_from)
+    elif cfg.load_from:
+        runner.load_checkpoint(cfg.load_from)
+    runner.run(data_loaders, cfg.workflow)
diff --git a/projects/mmdet3d_plugin/apis/test.py b/projects/mmdet3d_plugin/apis/test.py
new file mode 100644
index 0000000..f5fdcd3
--- /dev/null
+++ b/projects/mmdet3d_plugin/apis/test.py
@@ -0,0 +1,171 @@
+# ---------------------------------------------
+# Copyright (c) OpenMMLab. All rights reserved.
+# ---------------------------------------------
+# Modified by Zhiqi Li
+# ---------------------------------------------
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import time
+
+import mmcv
+import torch
+import torch.distributed as dist
+from mmcv.image import tensor2imgs
+from mmcv.runner import get_dist_info
+
+from mmdet.core import encode_mask_results
+
+
+import mmcv
+import numpy as np
+import pycocotools.mask as mask_util
+
+
+def custom_encode_mask_results(mask_results):
+    """Encode bitmap mask to RLE code. Semantic Masks only
+    Args:
+        mask_results (list | tuple[list]): bitmap mask results.
+            In mask scoring rcnn, mask_results is a tuple of (segm_results,
+            segm_cls_score).
+    Returns:
+        list | tuple: RLE encoded mask.
+    """
+    cls_segms = mask_results
+    num_classes = len(cls_segms)  # NOTE(review): unused — per-class count is never consumed below
+    encoded_mask_results = []
+    for i in range(len(cls_segms)):
+        encoded_mask_results.append(
+            mask_util.encode(  # pycocotools expects Fortran-ordered uint8 HxWx1
+                np.array(
+                    cls_segms[i][:, :, np.newaxis], order="F", dtype="uint8"
+                )
+            )[0]
+        )  # encoded with RLE
+    return [encoded_mask_results]
+
+
+def custom_multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
+    """Test model with multiple gpus.
+    This method tests model with multiple gpus and collects the results
+    under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
+    it encodes results to gpu tensors and use gpu communication for results
+    collection. On cpu mode it saves the results on different gpus to 'tmpdir'
+    and collects them by the rank 0 worker.
+    Args:
+        model (nn.Module): Model to be tested.
+        data_loader (nn.Dataloader): Pytorch data loader.
+        tmpdir (str): Path of directory to save the temporary results from
+            different gpus under cpu mode.
+        gpu_collect (bool): Option to use either gpu or cpu to collect results.
+    Returns:
+        list: The prediction results.
+    """
+    model.eval()
+    bbox_results = []
+    mask_results = []
+    dataset = data_loader.dataset
+    rank, world_size = get_dist_info()
+    if rank == 0:  # only rank 0 renders the progress bar
+        prog_bar = mmcv.ProgressBar(len(dataset))
+    time.sleep(2)  # This line can prevent deadlock problem in some cases.
+    have_mask = False
+    for i, data in enumerate(data_loader):
+        with torch.no_grad():
+            result = model(return_loss=False, rescale=True, **data)
+            # encode mask results
+            if isinstance(result, dict):  # dict form: separate bbox/mask result lists
+                if "bbox_results" in result.keys():
+                    bbox_result = result["bbox_results"]
+                    batch_size = len(result["bbox_results"])
+                    bbox_results.extend(bbox_result)
+                if (
+                    "mask_results" in result.keys()
+                    and result["mask_results"] is not None
+                ):
+                    mask_result = custom_encode_mask_results(
+                        result["mask_results"]
+                    )
+                    mask_results.extend(mask_result)
+                    have_mask = True
+            else:  # plain list of per-sample results
+                batch_size = len(result)
+                bbox_results.extend(result)
+
+        if rank == 0:  # advance bar by the global batch (all ranks)
+            for _ in range(batch_size * world_size):
+                prog_bar.update()
+
+    # collect results from all ranks
+    if gpu_collect:
+        bbox_results = collect_results_gpu(bbox_results, len(dataset))
+        if have_mask:
+            mask_results = collect_results_gpu(mask_results, len(dataset))
+        else:
+            mask_results = None
+    else:
+        bbox_results = collect_results_cpu(bbox_results, len(dataset), tmpdir)
+        tmpdir = tmpdir + "_mask" if tmpdir is not None else None  # separate dir for masks
+        if have_mask:
+            mask_results = collect_results_cpu(
+                mask_results, len(dataset), tmpdir
+            )
+        else:
+            mask_results = None
+
+    if mask_results is None:
+        return bbox_results
+    return {"bbox_results": bbox_results, "mask_results": mask_results}
+
+
+def collect_results_cpu(result_part, size, tmpdir=None):  # gather per-rank results via a shared dir; returns merged list on rank 0, None elsewhere
+    rank, world_size = get_dist_info()
+    # create a tmp dir if it is not specified
+    if tmpdir is None:
+        MAX_LEN = 512  # fixed broadcast buffer length for the dir path
+        # 32 is whitespace
+        dir_tensor = torch.full(
+            (MAX_LEN,), 32, dtype=torch.uint8, device="cuda"
+        )
+        if rank == 0:  # rank 0 creates the dir and broadcasts its path
+            mmcv.mkdir_or_exist(".dist_test")
+            tmpdir = tempfile.mkdtemp(dir=".dist_test")
+            tmpdir = torch.tensor(
+                bytearray(tmpdir.encode()), dtype=torch.uint8, device="cuda"
+            )
+            dir_tensor[: len(tmpdir)] = tmpdir
+        dist.broadcast(dir_tensor, 0)
+        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+    else:
+        mmcv.mkdir_or_exist(tmpdir)
+    # dump the part result to the dir
+    mmcv.dump(result_part, osp.join(tmpdir, f"part_{rank}.pkl"))
+    dist.barrier()  # ensure all parts are written before rank 0 reads
+    # collect all parts
+    if rank != 0:
+        return None
+    else:
+        # load results of all parts from tmp dir
+        part_list = []
+        for i in range(world_size):
+            part_file = osp.join(tmpdir, f"part_{i}.pkl")
+            part_list.append(mmcv.load(part_file))
+        # sort the results
+        ordered_results = []
+        """
+        bacause we change the sample of the evaluation stage to make sure that
+        each gpu will handle continuous sample,
+        """
+        # for res in zip(*part_list):
+        for res in part_list:  # parts are already contiguous per rank, so concatenate in rank order (no interleave)
+            ordered_results.extend(list(res))
+        # the dataloader may pad some samples
+        ordered_results = ordered_results[:size]
+        # remove tmp dir
+        shutil.rmtree(tmpdir)
+        return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+ collect_results_cpu(result_part, size)
diff --git a/projects/mmdet3d_plugin/apis/train.py b/projects/mmdet3d_plugin/apis/train.py
new file mode 100644
index 0000000..34cfa3e
--- /dev/null
+++ b/projects/mmdet3d_plugin/apis/train.py
@@ -0,0 +1,62 @@
+# ---------------------------------------------
+# Copyright (c) OpenMMLab. All rights reserved.
+# ---------------------------------------------
+# Modified by Zhiqi Li
+# ---------------------------------------------
+
+from .mmdet_train import custom_train_detector
+# from mmseg.apis import train_segmentor
+from mmdet.apis import train_detector
+
+
+def custom_train_model(
+    model,
+    dataset,
+    cfg,
+    distributed=False,
+    validate=False,
+    timestamp=None,
+    meta=None,
+):
+    """A function wrapper for launching model training according to cfg.
+
+    Because we need different eval_hook in runner. Should be deprecated in the
+    future.
+    """
+    if cfg.model.type in ["EncoderDecoder3D"]:  # segmentor training not supported by this plugin
+        assert False
+    else:
+        custom_train_detector(  # all arguments forwarded unchanged
+            model,
+            dataset,
+            cfg,
+            distributed=distributed,
+            validate=validate,
+            timestamp=timestamp,
+            meta=meta,
+        )
+
+
+def train_model(
+    model,
+    dataset,
+    cfg,
+    distributed=False,
+    validate=False,
+    timestamp=None,
+    meta=None,
+):
+    """A function wrapper for launching model training according to cfg.
+
+    Because we need different eval_hook in runner. Should be deprecated in the
+    future.
+    """
+    train_detector(  # plain mmdet training path (no custom eval hook)
+        model,
+        dataset,
+        cfg,
+        distributed=distributed,
+        validate=validate,
+        timestamp=timestamp,
+        meta=meta,
+    )
diff --git a/projects/mmdet3d_plugin/core/box3d.py b/projects/mmdet3d_plugin/core/box3d.py
new file mode 100644
index 0000000..93447e3
--- /dev/null
+++ b/projects/mmdet3d_plugin/core/box3d.py
@@ -0,0 +1,3 @@
+X, Y, Z, W, L, H, SIN_YAW, COS_YAW, VX, VY, VZ = list(range(11)) # undecoded
+CNS, YNS = 0, 1 # centerness and yawness indices in quality
+YAW = 6 # decoded
diff --git a/projects/mmdet3d_plugin/core/evaluation/__init__.py b/projects/mmdet3d_plugin/core/evaluation/__init__.py
new file mode 100644
index 0000000..d92421c
--- /dev/null
+++ b/projects/mmdet3d_plugin/core/evaluation/__init__.py
@@ -0,0 +1 @@
+from .eval_hooks import CustomDistEvalHook
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/core/evaluation/eval_hooks.py b/projects/mmdet3d_plugin/core/evaluation/eval_hooks.py
new file mode 100644
index 0000000..6a33bb9
--- /dev/null
+++ b/projects/mmdet3d_plugin/core/evaluation/eval_hooks.py
@@ -0,0 +1,97 @@
+# Note: Considering that MMCV's EvalHook updated its interface in V1.3.16,
+# in order to avoid strong version dependency, we did not directly
+# inherit EvalHook but BaseDistEvalHook.
+
+import bisect
+import os.path as osp
+
+import mmcv
+import torch.distributed as dist
+from mmcv.runner import DistEvalHook as BaseDistEvalHook
+from mmcv.runner import EvalHook as BaseEvalHook
+from torch.nn.modules.batchnorm import _BatchNorm
+from mmdet.core.evaluation.eval_hooks import DistEvalHook
+
+
def _calc_dynamic_intervals(start_interval, dynamic_interval_list):
    """Unzip ``(milestone, interval)`` pairs into two parallel lists.

    The milestone list is prefixed with 0 and the interval list with
    ``start_interval`` so that ``bisect`` can map any progress counter to
    the evaluation interval currently in effect.
    """
    assert mmcv.is_list_of(dynamic_interval_list, tuple)

    milestones = [0] + [pair[0] for pair in dynamic_interval_list]
    intervals = [start_interval] + [pair[1] for pair in dynamic_interval_list]
    return milestones, intervals
+
+
class CustomDistEvalHook(BaseDistEvalHook):
    """Distributed evaluation hook supporting dynamic evaluation intervals.

    ``dynamic_intervals`` is a list of ``(milestone, interval)`` tuples:
    once training progress (epoch or iteration, depending on ``by_epoch``)
    reaches a milestone, evaluation switches to the paired interval.
    Results are collected with the project's ``custom_multi_gpu_test``
    instead of mmdet's stock multi-GPU test function.
    """

    def __init__(self, *args, dynamic_intervals=None, **kwargs):
        super(CustomDistEvalHook, self).__init__(*args, **kwargs)
        # Pre-split the (milestone, interval) pairs into parallel lists so
        # _decide_interval can do a fast bisect lookup.
        self.use_dynamic_intervals = dynamic_intervals is not None
        if self.use_dynamic_intervals:
            (
                self.dynamic_milestones,
                self.dynamic_intervals,
            ) = _calc_dynamic_intervals(self.interval, dynamic_intervals)

    def _decide_interval(self, runner):
        # Select the interval of the milestone segment containing the *next*
        # progress step (hence progress + 1).
        if self.use_dynamic_intervals:
            progress = runner.epoch if self.by_epoch else runner.iter
            step = bisect.bisect(self.dynamic_milestones, (progress + 1))
            # Dynamically modify the evaluation interval
            self.interval = self.dynamic_intervals[step - 1]

    def before_train_epoch(self, runner):
        """Evaluate the model only at the start of training by epoch."""
        self._decide_interval(runner)
        super().before_train_epoch(runner)

    def before_train_iter(self, runner):
        # Refresh the interval before each iteration for iter-based runners.
        self._decide_interval(runner)
        super().before_train_iter(runner)

    def _do_evaluate(self, runner):
        """perform evaluation and save ckpt."""
        # Synchronization of BatchNorm's buffer (running_mean
        # and running_var) is not supported in the DDP of pytorch,
        # which may cause the inconsistent performance of models in
        # different ranks, so we broadcast BatchNorm's buffers
        # of rank 0 to other ranks to avoid this.
        if self.broadcast_bn_buffer:
            model = runner.model
            for name, module in model.named_modules():
                if (
                    isinstance(module, _BatchNorm)
                    and module.track_running_stats
                ):
                    dist.broadcast(module.running_var, 0)
                    dist.broadcast(module.running_mean, 0)

        if not self._should_evaluate(runner):
            return

        tmpdir = self.tmpdir
        if tmpdir is None:
            tmpdir = osp.join(runner.work_dir, ".eval_hook")

        # Imported here (not at module level) to avoid a circular import.
        from projects.mmdet3d_plugin.apis.test import (
            custom_multi_gpu_test,
        )

        results = custom_multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=tmpdir,
            gpu_collect=self.gpu_collect,
        )
        # Only rank 0 aggregates results, evaluates and saves checkpoints.
        if runner.rank == 0:
            print("\n")
            # Expose the number of eval batches so loggers can report it.
            runner.log_buffer.output["eval_iter_num"] = len(self.dataloader)

            key_score = self.evaluate(runner, results)

            if self.save_best:
                self._save_ckpt(runner, key_score)
diff --git a/projects/mmdet3d_plugin/datasets/__init__.py b/projects/mmdet3d_plugin/datasets/__init__.py
new file mode 100644
index 0000000..ac90a58
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/__init__.py
@@ -0,0 +1,9 @@
+from .nuscenes_3d_dataset import NuScenes3DDataset
+from .builder import *
+from .pipelines import *
+from .samplers import *
+
+__all__ = [
+ 'NuScenes3DDataset',
+ "custom_build_dataset",
+]
diff --git a/projects/mmdet3d_plugin/datasets/builder.py b/projects/mmdet3d_plugin/datasets/builder.py
new file mode 100644
index 0000000..ab30f9d
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/builder.py
@@ -0,0 +1,192 @@
+import copy
+import platform
+import random
+from functools import partial
+
+import numpy as np
+from mmcv.parallel import collate
+from mmcv.runner import get_dist_info
+from mmcv.utils import Registry, build_from_cfg
+from torch.utils.data import DataLoader
+
+from mmdet.datasets.samplers import GroupSampler
+from projects.mmdet3d_plugin.datasets.samplers import (
+ GroupInBatchSampler,
+ DistributedGroupSampler,
+ DistributedSampler,
+ build_sampler
+)
+
+
def build_dataloader(
    dataset,
    samples_per_gpu,
    workers_per_gpu,
    num_gpus=1,
    dist=True,
    shuffle=True,
    seed=None,
    shuffler_sampler=None,
    nonshuffler_sampler=None,
    runner_type="EpochBasedRunner",
    **kwargs
):
    """Build PyTorch DataLoader.
    In distributed training, each GPU/process has a dataloader.
    In non-distributed training, there is only one dataloader for all GPUs.
    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Number of training samples on each GPU, i.e.,
            batch size of each GPU.
        workers_per_gpu (int): How many subprocesses to use for data loading
            for each GPU.
        num_gpus (int): Number of GPUs. Only used in non-distributed training.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to shuffle the data at every epoch.
            Default: True.
        seed (int | None): Base random seed forwarded to the sampler and the
            per-worker init function. Default: None.
        shuffler_sampler (dict | None): Sampler config used when ``shuffle``
            is True; falls back to ``DistributedGroupSampler``.
        nonshuffler_sampler (dict | None): Sampler config used when
            ``shuffle`` is False; falls back to ``DistributedSampler``.
        runner_type (str): When "IterBasedRunner", a ``GroupInBatchSampler``
            is used and batching is delegated to it entirely.
        kwargs: any keyword argument to be used to initialize DataLoader
    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()
    batch_sampler = None
    if runner_type == 'IterBasedRunner':
        # Iteration-based training: the batch sampler owns batching, so the
        # DataLoader's own batch_size must stay 1 and sampler must be None.
        print("Use GroupInBatchSampler !!!")
        batch_sampler = GroupInBatchSampler(
            dataset,
            samples_per_gpu,
            world_size,
            rank,
            seed=seed,
        )
        batch_size = 1
        sampler = None
        num_workers = workers_per_gpu
    elif dist:
        # DistributedGroupSampler will definitely shuffle the data to satisfy
        # that images on each GPU are in the same group
        if shuffle:
            print("Use DistributedGroupSampler !!!")
            sampler = build_sampler(
                shuffler_sampler
                if shuffler_sampler is not None
                else dict(type="DistributedGroupSampler"),
                dict(
                    dataset=dataset,
                    samples_per_gpu=samples_per_gpu,
                    num_replicas=world_size,
                    rank=rank,
                    seed=seed,
                ),
            )
        else:
            sampler = build_sampler(
                nonshuffler_sampler
                if nonshuffler_sampler is not None
                else dict(type="DistributedSampler"),
                dict(
                    dataset=dataset,
                    num_replicas=world_size,
                    rank=rank,
                    shuffle=shuffle,
                    seed=seed,
                ),
            )

        batch_size = samples_per_gpu
        num_workers = workers_per_gpu
    else:
        # Non-distributed path: a single dataloader serves all GPUs.
        # assert False, 'not support in bevformer'
        print("WARNING!!!!, Only can be used for obtain inference speed!!!!")
        sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    # Seed workers deterministically only when a base seed is provided.
    init_fn = (
        partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
        if seed is not None
        else None
    )

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=False,
        worker_init_fn=init_fn,
        **kwargs
    )

    return data_loader
+
+
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed NumPy and Python RNGs uniquely per dataloader worker.

    The per-worker seed is ``num_workers * rank + worker_id + seed``: no two
    workers across all ranks collide, while a fixed user seed keeps data
    augmentation reproducible between runs.
    """
    worker_seed = seed + rank * num_workers + worker_id
    np.random.seed(worker_seed)
    random.seed(worker_seed)
+
+
+# Copyright (c) OpenMMLab. All rights reserved.
+import platform
+from mmcv.utils import Registry, build_from_cfg
+
+from mmdet.datasets import DATASETS
+from mmdet.datasets.builder import _concat_dataset
+
+if platform.system() != "Windows":
+ # https://github.com/pytorch/pytorch/issues/973
+ import resource
+
+ rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+ base_soft_limit = rlimit[0]
+ hard_limit = rlimit[1]
+ soft_limit = min(max(4096, base_soft_limit), hard_limit)
+ resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
+
+OBJECTSAMPLERS = Registry("Object sampler")
+
+
def custom_build_dataset(cfg, default_args=None):
    """Recursively build a dataset (or dataset wrapper) from config.

    Mirrors mmdet's ``build_dataset`` but additionally supports mmdet3d's
    ``CBGSDataset`` wrapper when mmdet3d is installed.

    Args:
        cfg (dict | list | tuple): Dataset config, or a sequence of configs
            which are built individually and concatenated.
        default_args (dict | None): Default kwargs forwarded to
            ``build_from_cfg`` for leaf datasets.

    Returns:
        Dataset: The constructed (possibly wrapped) dataset.

    Raises:
        ImportError: if a ``CBGSDataset`` is requested but mmdet3d could
            not be imported.
    """
    try:
        from mmdet3d.datasets.dataset_wrappers import CBGSDataset
    except Exception:
        # mmdet3d is optional; only the CBGSDataset branch needs it.
        # (Was a bare ``except:`` — narrowed so SystemExit/KeyboardInterrupt
        # are no longer swallowed.)
        CBGSDataset = None
    from mmdet.datasets.dataset_wrappers import (
        ClassBalancedDataset,
        ConcatDataset,
        RepeatDataset,
    )

    if isinstance(cfg, (list, tuple)):
        # A sequence of configs: build each one and concatenate them.
        dataset = ConcatDataset(
            [custom_build_dataset(c, default_args) for c in cfg]
        )
    elif cfg["type"] == "ConcatDataset":
        dataset = ConcatDataset(
            [custom_build_dataset(c, default_args) for c in cfg["datasets"]],
            cfg.get("separate_eval", True),
        )
    elif cfg["type"] == "RepeatDataset":
        dataset = RepeatDataset(
            custom_build_dataset(cfg["dataset"], default_args), cfg["times"]
        )
    elif cfg["type"] == "ClassBalancedDataset":
        dataset = ClassBalancedDataset(
            custom_build_dataset(cfg["dataset"], default_args),
            cfg["oversample_thr"],
        )
    elif cfg["type"] == "CBGSDataset":
        if CBGSDataset is None:
            # Clear failure instead of "'NoneType' object is not callable".
            raise ImportError(
                "CBGSDataset requires mmdet3d, which could not be imported"
            )
        dataset = CBGSDataset(
            custom_build_dataset(cfg["dataset"], default_args)
        )
    elif isinstance(cfg.get("ann_file"), (list, tuple)):
        # Multiple annotation files for one dataset type.
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset
diff --git a/projects/mmdet3d_plugin/datasets/evaluation/__init__.py b/projects/mmdet3d_plugin/datasets/evaluation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/projects/mmdet3d_plugin/datasets/evaluation/det/__init__.py b/projects/mmdet3d_plugin/datasets/evaluation/det/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/projects/mmdet3d_plugin/datasets/evaluation/det/occluded_det_eval.py b/projects/mmdet3d_plugin/datasets/evaluation/det/occluded_det_eval.py
new file mode 100644
index 0000000..0c47320
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/evaluation/det/occluded_det_eval.py
@@ -0,0 +1,521 @@
+import numpy as np
+from collections import Counter
+from typing import Callable, Dict, Optional, Tuple
+
+from nuscenes.eval.common.data_classes import EvalBoxes
+from nuscenes.eval.common.loaders import load_gt, add_center_dist
+from nuscenes.eval.common.utils import (
+ center_distance, scale_iou, yaw_diff, velocity_l2, attr_acc, cummean,
+)
+from nuscenes.eval.detection.data_classes import DetectionBox, DetectionMetricData
+from nuscenes.eval.detection.evaluate import NuScenesEval
+
+
# ---------------------------------------------------------------------------
# Adaptive matching threshold coefficients (UniTraj class mapping)
# Formula: d = dist_th + alpha*t + beta*v*t + gamma*a*t^2
# ---------------------------------------------------------------------------

# Per-superclass (alpha, beta, gamma) coefficients in the formula above.
ADAPTIVE_COEFFS = {
    'vehicle': (0.0568, 0.1962, 0.2133),
    'cyclist': (0.1023, 0.1861, 0.2266),
    'pedestrian': (0.2641, 0.1457, 0.1774),
}

# Maps nuScenes detection classes onto the three UniTraj superclasses.
# Classes absent from this dict are treated as static and keep the fixed
# matching threshold.
NUSCENES_TO_UNITRAJ = {
    'car': 'vehicle',
    'truck': 'vehicle',
    'construction_vehicle': 'vehicle',
    'bus': 'vehicle',
    'trailer': 'vehicle',
    'motorcycle': 'cyclist',
    'bicycle': 'cyclist',
    'pedestrian': 'pedestrian',
    # barrier, traffic_cone: no adaptive threshold
}
+
+
+# ---------------------------------------------------------------------------
+# Adaptive threshold helpers
+# ---------------------------------------------------------------------------
+
+def _ann_speed(nusc, ann: dict) -> float:
+ """Speed (m/s) of an annotation estimated from position difference to prev."""
+ if ann['prev'] == '':
+ return 0.0
+ prev_ann = nusc.get('sample_annotation', ann['prev'])
+ curr_ts = nusc.get('sample', ann['sample_token'])['timestamp']
+ prev_ts = nusc.get('sample', prev_ann['sample_token'])['timestamp']
+ dt = (curr_ts - prev_ts) * 1e-6
+ if dt <= 0:
+ return 0.0
+ dx = ann['translation'][0] - prev_ann['translation'][0]
+ dy = ann['translation'][1] - prev_ann['translation'][1]
+ return np.sqrt(dx ** 2 + dy ** 2) / dt
+
+
def _occ_metadata(nusc, ann_token: str) -> Tuple[float, float, float]:
    """Return (t, v, a) for an occluded annotation.

    t : occlusion duration in seconds (≥ 0.5 s since current frame is occluded)
    v : speed at the last visible annotation (m/s)
    a : acceleration magnitude at the last visible annotation (m/s²)
    """
    ann = nusc.get('sample_annotation', ann_token)

    # Walk prev-links counting consecutive occluded frames and finding the
    # last visible annotation and the one before it.
    t_frames = 1  # current frame counts as 1 occluded frame
    last_vis = None
    prev_of_last_vis = None

    prev_token = ann['prev']
    while prev_token != '':
        prev_ann = nusc.get('sample_annotation', prev_token)
        if prev_ann['num_lidar_pts'] > 0:
            # First frame with lidar points ends the occluded streak.
            last_vis = prev_ann
            if prev_ann['prev'] != '':
                prev_of_last_vis = nusc.get('sample_annotation', prev_ann['prev'])
            break
        t_frames += 1
        prev_token = prev_ann['prev']

    t = t_frames * 0.5  # NuScenes is 2 Hz → 0.5 s per frame

    # Speed at the last visible frame (0 when the track starts occluded).
    v = _ann_speed(nusc, last_vis) if last_vis is not None else 0.0

    # Acceleration needs two consecutive speed estimates, hence both the
    # last visible annotation and its predecessor must exist.
    a = 0.0
    if last_vis is not None and prev_of_last_vis is not None:
        v_last = _ann_speed(nusc, last_vis)
        v_prev = _ann_speed(nusc, prev_of_last_vis)
        curr_ts = nusc.get('sample', last_vis['sample_token'])['timestamp']
        prev_ts = nusc.get('sample', prev_of_last_vis['sample_token'])['timestamp']
        dt = (curr_ts - prev_ts) * 1e-6
        if dt > 0:
            a = abs(v_last - v_prev) / dt

    return t, v, a
+
+
def _adaptive_dist_th(
    base_dist_th: float,
    class_name: str,
    t: float,
    v: float,
    a: float,
) -> float:
    """d = dist_th + alpha*t + beta*v*t + gamma*a*t^2"""
    mapped = NUSCENES_TO_UNITRAJ.get(class_name)
    if mapped is None:
        # Static classes (barrier, traffic_cone) keep the fixed threshold.
        return base_dist_th
    alpha, beta, gamma = ADAPTIVE_COEFFS[mapped]
    return base_dist_th + alpha * t + beta * v * t + gamma * a * t ** 2
+
+
+# ---------------------------------------------------------------------------
+# Custom accumulate with three-outcome matching
+# ---------------------------------------------------------------------------
+
def accumulate_with_ignore(
    gt_boxes: EvalBoxes,
    pred_boxes: EvalBoxes,
    ignore_boxes: EvalBoxes,
    class_name: str,
    dist_fcn: Callable,
    dist_th: float,
    per_gt_dist_ths: Optional[Dict[Tuple[str, int], float]] = None,
    verbose: bool = False,
) -> DetectionMetricData:
    """AP accumulation with a three-outcome matching rule.

    For each prediction (processed in descending confidence order):

    1. **TP** – prediction matches an unmatched visible GT box within its
       effective distance threshold (adaptive if per_gt_dist_ths provided,
       otherwise the fixed dist_th).
    2. **Ignored** – prediction does not match visible GT, but matches an
       ignore box within dist_th. The prediction is excluded from both the
       numerator and denominator of the precision-recall curve.
    3. **FP** – prediction matches neither.

    Parameters
    ----------
    gt_boxes: Visible GT boxes used for scoring (TP/FP/recall denominator).
    pred_boxes: All model predictions.
    ignore_boxes: GT boxes that neutralise unmatched preds.
    class_name: Detection class to evaluate.
    dist_fcn: BEV distance function.
    dist_th: Base match / ignore distance threshold in metres.
    per_gt_dist_ths: Optional dict mapping (sample_token, gt_idx) → adaptive
                     threshold for TP matching. Ignore matching always uses
                     the fixed dist_th.
    """
    # Number of positives of this class — the recall denominator.
    npos = len([1 for b in gt_boxes.all if b.detection_name == class_name])
    if verbose:
        print(f'Found {npos} GT of class {class_name} across '
              f'{len(gt_boxes.sample_tokens)} samples.')

    if npos == 0:
        return DetectionMetricData.no_predictions()

    # Pre-bucket same-class ignore boxes by sample token for O(1) lookup.
    ignore_by_token = {
        t: [b for b in ignore_boxes[t] if b.detection_name == class_name]
        for t in ignore_boxes.sample_tokens
    }

    pred_boxes_list = [b for b in pred_boxes.all if b.detection_name == class_name]
    pred_confs = [b.detection_score for b in pred_boxes_list]
    # Indices of predictions sorted by descending confidence.
    sortind = [i for (v, i) in sorted((v, i) for (i, v) in enumerate(pred_confs))][::-1]

    tp = []
    fp = []
    conf = []
    # TP-error statistics, accumulated only for matched predictions.
    match_data = {
        'trans_err': [], 'vel_err': [], 'scale_err': [],
        'orient_err': [], 'attr_err': [], 'conf': [],
    }

    # (sample_token, gt_idx) pairs already claimed by a higher-scoring pred.
    taken = set()

    for ind in sortind:
        pred_box = pred_boxes_list[ind]

        # --- Step 1: find nearest unmatched GT ---
        min_dist = np.inf
        match_gt_idx = None
        for gt_idx, gt_box in enumerate(gt_boxes[pred_box.sample_token]):
            if gt_box.detection_name != class_name:
                continue
            if (pred_box.sample_token, gt_idx) in taken:
                continue
            d = dist_fcn(gt_box, pred_box)
            if d < min_dist:
                min_dist = d
                match_gt_idx = gt_idx

        # Resolve effective threshold for the nearest GT box.
        if match_gt_idx is not None and per_gt_dist_ths is not None:
            eff_dist_th = per_gt_dist_ths.get(
                (pred_box.sample_token, match_gt_idx), dist_th
            )
        else:
            eff_dist_th = dist_th

        if min_dist < eff_dist_th:
            # True positive: record TP-error metrics against the matched GT.
            taken.add((pred_box.sample_token, match_gt_idx))
            tp.append(1)
            fp.append(0)
            conf.append(pred_box.detection_score)

            gt_match = gt_boxes[pred_box.sample_token][match_gt_idx]
            match_data['trans_err'].append(center_distance(gt_match, pred_box))
            match_data['vel_err'].append(velocity_l2(gt_match, pred_box))
            match_data['scale_err'].append(1 - scale_iou(gt_match, pred_box))
            # Barrier orientation is only defined modulo 180 degrees.
            period = np.pi if class_name == 'barrier' else 2 * np.pi
            match_data['orient_err'].append(yaw_diff(gt_match, pred_box, period=period))
            match_data['attr_err'].append(1 - attr_acc(gt_match, pred_box))
            match_data['conf'].append(pred_box.detection_score)

        else:
            # Step 2: check ignore boxes (always with fixed dist_th).
            ignores = ignore_by_token.get(pred_box.sample_token, [])
            is_ignored = any(dist_fcn(ign, pred_box) < dist_th for ign in ignores)

            if is_ignored:
                # Step 2 outcome: drop the prediction from the PR curve.
                continue

            # Step 3: neither GT nor ignore matched — false positive.
            tp.append(0)
            fp.append(1)
            conf.append(pred_box.detection_score)

    # No prediction matched any GT at all: early out, as in the devkit.
    if len(match_data['trans_err']) == 0:
        return DetectionMetricData.no_predictions()

    # Cumulative TP/FP in confidence order → raw precision/recall curve.
    tp = np.cumsum(tp).astype(float)
    fp = np.cumsum(fp).astype(float)
    conf = np.array(conf)

    prec = tp / (fp + tp)
    rec = tp / float(npos)

    # Resample precision and confidence onto a fixed recall grid.
    rec_interp = np.linspace(0, 1, DetectionMetricData.nelem)
    prec = np.interp(rec_interp, rec, prec, right=0)
    conf = np.interp(rec_interp, rec, conf, right=0)
    rec = rec_interp

    # Resample cumulative-mean TP errors onto the same grid via confidence
    # (both axes reversed so np.interp sees increasing x values).
    for key in match_data:
        if key == 'conf':
            continue
        tmp = cummean(np.array(match_data[key]))
        match_data[key] = np.interp(
            conf[::-1], match_data['conf'][::-1], tmp[::-1]
        )[::-1]

    return DetectionMetricData(
        recall=rec, precision=prec, confidence=conf,
        trans_err=match_data['trans_err'], vel_err=match_data['vel_err'],
        scale_err=match_data['scale_err'], orient_err=match_data['orient_err'],
        attr_err=match_data['attr_err'],
    )
+
+
+# ---------------------------------------------------------------------------
+# Base evaluator
+# ---------------------------------------------------------------------------
+
class _IgnoreAwareNuScenesEval(NuScenesEval):
    """Base class for ignore-aware detection evaluators.

    Provides GT filters that split boxes into occluded (num_pts == 0) and
    visible (num_pts > 0) subsets, plus an ``evaluate`` that scores with
    :func:`accumulate_with_ignore`. Subclasses decide which subset is scored
    and which only neutralises unmatched predictions, by setting
    ``self.gt_boxes`` and ``self._ignore_boxes``.
    """

    def _filter_occluded_boxes(self, gt_boxes: EvalBoxes) -> EvalBoxes:
        # Occluded = zero lidar points, known class, inside its eval range.
        filtered = EvalBoxes()
        for sample_token in gt_boxes.sample_tokens:
            boxes = [
                box for box in gt_boxes[sample_token]
                if box.num_pts == 0
                and box.detection_name in self.cfg.class_range
                and box.ego_dist < self.cfg.class_range[box.detection_name]
            ]
            filtered.add_boxes(sample_token, boxes)
        return filtered

    def _filter_visible_boxes(self, gt_boxes: EvalBoxes) -> EvalBoxes:
        # Visible = at least one lidar point, known class, inside its range.
        filtered = EvalBoxes()
        for sample_token in gt_boxes.sample_tokens:
            boxes = [
                box for box in gt_boxes[sample_token]
                if box.num_pts > 0
                and box.detection_name in self.cfg.class_range
                and box.ego_dist < self.cfg.class_range[box.detection_name]
            ]
            filtered.add_boxes(sample_token, boxes)
        return filtered

    def evaluate(self):
        """Run ignore-aware accumulation and compute AP / TP metrics.

        Returns:
            tuple: (DetectionMetrics, DetectionMetricDataList)
        """
        import time
        from nuscenes.eval.detection.algo import calc_ap, calc_tp
        from nuscenes.eval.detection.data_classes import (
            DetectionMetrics, DetectionMetricDataList,
        )
        from nuscenes.eval.detection.constants import TP_METRICS

        start_time = time.time()
        if self.verbose:
            print('Accumulating metric data...')

        metric_data_list = DetectionMetricDataList()
        for class_name in self.cfg.class_names:
            for dist_th in self.cfg.dist_ths:
                md = accumulate_with_ignore(
                    self.gt_boxes,
                    self.pred_boxes,
                    self._ignore_boxes,
                    class_name,
                    self.cfg.dist_fcn_callable,
                    dist_th,
                )
                metric_data_list.set(class_name, dist_th, md)

        if self.verbose:
            print('Calculating metrics...')

        metrics = DetectionMetrics(self.cfg)
        for class_name in self.cfg.class_names:
            for dist_th in self.cfg.dist_ths:
                metric_data = metric_data_list[(class_name, dist_th)]
                ap = calc_ap(metric_data, self.cfg.min_recall, self.cfg.min_precision)
                metrics.add_label_ap(class_name, dist_th, ap)

            for metric_name in TP_METRICS:
                # TP metrics use the single reference threshold (dist_th_tp).
                metric_data = metric_data_list[(class_name, self.cfg.dist_th_tp)]
                # Some metrics are undefined for static classes, matching the
                # official devkit's conventions.
                if class_name == 'traffic_cone' and metric_name in ('attr_err', 'vel_err', 'orient_err'):
                    tp = np.nan
                elif class_name == 'barrier' and metric_name in ('attr_err', 'vel_err'):
                    tp = np.nan
                else:
                    tp = calc_tp(metric_data, self.cfg.min_recall, metric_name)
                metrics.add_label_tp(class_name, metric_name, tp)

        metrics.add_runtime(time.time() - start_time)
        return metrics, metric_data_list
+
+
+# ---------------------------------------------------------------------------
+# Concrete evaluators
+# ---------------------------------------------------------------------------
+
class OccludedDetectionEval(_IgnoreAwareNuScenesEval):
    """NuScenes detection evaluator restricted to occluded objects.

    Uses an adaptive matching threshold d = dist_th + alpha*t + beta*v*t +
    gamma*a*t^2 where t is occlusion duration, v is speed and a is
    acceleration at the last visible annotation. Coefficients are
    class-specific (Vehicle / Cyclist / Pedestrian via UniTraj mapping).
    Static classes (barrier, traffic_cone) retain the fixed dist_th.
    """

    def __init__(self, nusc, config, result_path, eval_set, output_dir, verbose):
        super().__init__(nusc, config, result_path, eval_set, output_dir, verbose)

        # Re-load the full GT and split it: occluded boxes are scored while
        # visible boxes only neutralise unmatched predictions.
        all_gt = load_gt(nusc, eval_set, DetectionBox, verbose=verbose)
        all_gt = add_center_dist(nusc, all_gt)

        self.gt_boxes = self._filter_occluded_boxes(all_gt)
        self.sample_tokens = self.gt_boxes.sample_tokens
        self._ignore_boxes = self._filter_visible_boxes(all_gt)

        occ_total = sum(len(self.gt_boxes[t]) for t in self.gt_boxes.sample_tokens)
        class_counts = Counter(
            box.detection_name
            for t in self.gt_boxes.sample_tokens
            for box in self.gt_boxes[t]
        )
        vis_total = sum(len(self._ignore_boxes[t]) for t in self._ignore_boxes.sample_tokens)
        print(f'[Occluded Det] GT occluded boxes: {occ_total} | {dict(class_counts)}')
        print(f'[Occluded Det] Ignore (visible) boxes: {vis_total}')

        # Pre-compute adaptive thresholds for each (base_dist_th).
        self._adaptive_dist_ths = self._build_adaptive_dist_ths(nusc)

    def _build_adaptive_dist_ths(
        self, nusc
    ) -> Dict[float, Dict[Tuple[str, int], float]]:
        """Build per-GT-box adaptive thresholds for every base dist_th.

        Returns
        -------
        dict mapping base_dist_th → {(sample_token, gt_idx): adaptive_dist_th}
        """
        # Build sample_token → {rounded_translation: ann_token} for fast lookup.
        # EvalBoxes carry no annotation tokens, so boxes are re-associated
        # with annotations by their BEV position rounded to 2 decimals.
        sample_ann_map: Dict[str, Dict[tuple, str]] = {}
        for sample_token in self.gt_boxes.sample_tokens:
            sample = nusc.get('sample', sample_token)
            pos_to_tok = {}
            for ann_token in sample['anns']:
                ann = nusc.get('sample_annotation', ann_token)
                key = (round(ann['translation'][0], 2),
                       round(ann['translation'][1], 2))
                pos_to_tok[key] = ann_token
            sample_ann_map[sample_token] = pos_to_tok

        # Compute (t, v, a) and cache adaptive threshold per box per dist_th.
        result: Dict[float, Dict[Tuple[str, int], float]] = {
            d: {} for d in self.cfg.dist_ths
        }

        for sample_token in self.gt_boxes.sample_tokens:
            pos_to_tok = sample_ann_map[sample_token]
            for gt_idx, box in enumerate(self.gt_boxes[sample_token]):
                key_pos = (round(box.translation[0], 2),
                           round(box.translation[1], 2))
                ann_token = pos_to_tok.get(key_pos)
                if ann_token is None:
                    continue  # fallback: keep base dist_th (entry absent → default)

                t, v, a = _occ_metadata(nusc, ann_token)
                for base_dist_th in self.cfg.dist_ths:
                    adaptive = _adaptive_dist_th(base_dist_th, box.detection_name,
                                                 t, v, a)
                    result[base_dist_th][(sample_token, gt_idx)] = adaptive

        return result

    def evaluate(self):
        """Override to pass adaptive per-GT thresholds to accumulate_with_ignore."""
        # NOTE(review): this duplicates the base-class evaluate except for
        # the per_gt_dist_ths argument; consider folding the parameter into
        # the base implementation.
        import time
        from nuscenes.eval.detection.algo import calc_ap, calc_tp
        from nuscenes.eval.detection.data_classes import (
            DetectionMetrics, DetectionMetricDataList,
        )
        from nuscenes.eval.detection.constants import TP_METRICS

        start_time = time.time()
        if self.verbose:
            print('Accumulating metric data (adaptive thresholds)...')

        metric_data_list = DetectionMetricDataList()
        for class_name in self.cfg.class_names:
            for dist_th in self.cfg.dist_ths:
                md = accumulate_with_ignore(
                    self.gt_boxes,
                    self.pred_boxes,
                    self._ignore_boxes,
                    class_name,
                    self.cfg.dist_fcn_callable,
                    dist_th,
                    per_gt_dist_ths=self._adaptive_dist_ths.get(dist_th),
                )
                metric_data_list.set(class_name, dist_th, md)

        if self.verbose:
            print('Calculating metrics...')

        metrics = DetectionMetrics(self.cfg)
        for class_name in self.cfg.class_names:
            for dist_th in self.cfg.dist_ths:
                metric_data = metric_data_list[(class_name, dist_th)]
                ap = calc_ap(metric_data, self.cfg.min_recall, self.cfg.min_precision)
                metrics.add_label_ap(class_name, dist_th, ap)

            for metric_name in TP_METRICS:
                # TP metrics are evaluated at the single reference threshold.
                metric_data = metric_data_list[(class_name, self.cfg.dist_th_tp)]
                if class_name == 'traffic_cone' and metric_name in ('attr_err', 'vel_err', 'orient_err'):
                    tp = np.nan
                elif class_name == 'barrier' and metric_name in ('attr_err', 'vel_err'):
                    tp = np.nan
                else:
                    tp = calc_tp(metric_data, self.cfg.min_recall, metric_name)
                metrics.add_label_tp(class_name, metric_name, tp)

        metrics.add_runtime(time.time() - start_time)
        return metrics, metric_data_list
+
+
class VisibleDetectionEval(_IgnoreAwareNuScenesEval):
    """NuScenes detection evaluator on visible objects (num_lidar_pts >= 1).

    NOTE(review): ``self.gt_boxes`` is whatever the base ``NuScenesEval``
    constructor loaded and filtered; only the ignore set (occluded boxes)
    is recomputed here. Confirm the base-class filtering indeed retains
    visible boxes only.
    """

    def __init__(self, nusc, config, result_path, eval_set, output_dir, verbose):
        super().__init__(nusc, config, result_path, eval_set, output_dir, verbose)

        # Occluded GT neutralises unmatched predictions instead of counting
        # them as false positives.
        all_gt = load_gt(nusc, eval_set, DetectionBox, verbose=verbose)
        all_gt = add_center_dist(nusc, all_gt)
        self._ignore_boxes = self._filter_occluded_boxes(all_gt)

        vis_total = sum(len(self.gt_boxes[t]) for t in self.gt_boxes.sample_tokens)
        class_counts = Counter(
            box.detection_name
            for t in self.gt_boxes.sample_tokens
            for box in self.gt_boxes[t]
        )
        occ_total = sum(len(self._ignore_boxes[t]) for t in self._ignore_boxes.sample_tokens)
        print(f'[Visible Det] GT visible boxes: {vis_total} | {dict(class_counts)}')
        print(f'[Visible Det] Ignore (occluded) boxes: {occ_total}')
+
+
class AllDetectionEval(NuScenesEval):
    """NuScenes detection evaluator on all objects (visible + occluded)."""

    def __init__(self, nusc, config, result_path, eval_set, output_dir, verbose):
        super().__init__(nusc, config, result_path, eval_set, output_dir, verbose)

        # Replace the base-class GT with a visibility-agnostic set: only
        # class membership and the per-class eval range are enforced.
        self.gt_boxes = load_gt(nusc, eval_set, DetectionBox, verbose=verbose)
        self.gt_boxes = add_center_dist(nusc, self.gt_boxes)
        self.gt_boxes = self._filter_all_gt(self.gt_boxes)
        self.sample_tokens = self.gt_boxes.sample_tokens

    def _filter_all_gt(self, gt_boxes: EvalBoxes) -> EvalBoxes:
        # Keep every box of a known class inside its eval range, regardless
        # of how many lidar points it contains.
        filtered = EvalBoxes()
        for sample_token in gt_boxes.sample_tokens:
            boxes = [
                box for box in gt_boxes[sample_token]
                if box.detection_name in self.cfg.class_range
                and box.ego_dist < self.cfg.class_range[box.detection_name]
            ]
            filtered.add_boxes(sample_token, boxes)
        total = sum(len(filtered[t]) for t in filtered.sample_tokens)
        class_counts = Counter(
            box.detection_name
            for t in filtered.sample_tokens
            for box in filtered[t]
        )
        print(f'[All Det] GT all boxes: {total} | {dict(class_counts)}')
        return filtered
diff --git a/projects/mmdet3d_plugin/datasets/evaluation/map/AP.py b/projects/mmdet3d_plugin/datasets/evaluation/map/AP.py
new file mode 100644
index 0000000..0be3480
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/evaluation/map/AP.py
@@ -0,0 +1,136 @@
+import numpy as np
+from .distance import chamfer_distance, frechet_distance, chamfer_distance_batch
+from typing import List, Tuple, Union
+from numpy.typing import NDArray
+
def average_precision(recalls, precisions, mode='area'):
    """Calculate average precision.

    Args:
        recalls (ndarray): shape (num_dets, )
        precisions (ndarray): shape (num_dets, )
        mode (str): 'area' or '11points', 'area' means calculating the area
            under precision-recall curve, '11points' means calculating
            the average precision of recalls at [0, 0.1, ..., 1]

    Returns:
        float: calculated average precision

    Raises:
        ValueError: if ``mode`` is neither 'area' nor '11points'.
    """
    # Promote to 2-D (1, num_dets) so the math below mirrors the
    # multi-scale implementation this was adapted from.
    recalls = recalls[np.newaxis, :]
    precisions = precisions[np.newaxis, :]

    assert recalls.shape == precisions.shape and recalls.ndim == 2
    num_scales = recalls.shape[0]
    ap = 0.

    if mode == 'area':
        # Pad the PR curve with end points (recall 0 and 1, precision 0),
        # then take the running maximum from the right so precision becomes
        # monotonically non-increasing.
        zeros = np.zeros((num_scales, 1), dtype=recalls.dtype)
        ones = np.ones((num_scales, 1), dtype=recalls.dtype)
        mrec = np.hstack((zeros, recalls, ones))
        mpre = np.hstack((zeros, precisions, zeros))
        for i in range(mpre.shape[1] - 1, 0, -1):
            mpre[:, i - 1] = np.maximum(mpre[:, i - 1], mpre[:, i])

        # Integrate the envelope over the distinct recall steps.
        ind = np.where(mrec[0, 1:] != mrec[0, :-1])[0]
        ap = np.sum(
            (mrec[0, ind + 1] - mrec[0, ind]) * mpre[0, ind + 1])

    elif mode == '11points':
        for thr in np.arange(0, 1 + 1e-3, 0.1):
            # Bugfix: the original indexed ``recalls[i, :]`` with an
            # undefined ``i`` (leftover from the multi-scale version),
            # which raised NameError; row 0 is the only row here.
            precs = precisions[0, recalls[0, :] >= thr]
            prec = precs.max() if precs.size > 0 else 0
            ap += prec
        ap /= 11
    else:
        raise ValueError(
            'Unrecognized mode, only "area" and "11points" are supported')

    return ap
+
def instance_match(pred_lines: NDArray,
                   scores: NDArray,
                   gt_lines: NDArray,
                   thresholds: Union[Tuple, List],
                   metric: str='chamfer') -> List:
    """Compute whether detected lines are true positive or false positive.

    Args:
        pred_lines (array): Detected lines of a sample, of shape (M, INTERP_NUM, 2 or 3).
        scores (array): Confidence score of each line, of shape (M, ).
        gt_lines (array): GT lines of a sample, of shape (N, INTERP_NUM, 2 or 3).
        thresholds (list of tuple): List of thresholds.
        metric (str): Distance function for lines matching. Default: 'chamfer'.

    Returns:
        list_of_tp_fp (list): tp-fp matching result at all thresholds;
            one (tp, fp) pair of float32 arrays of shape (M,) per threshold.

    Raises:
        ValueError: if ``metric`` names an unknown distance function.
    """

    # NOTE(review): distance_fn is selected but unused below — the batched
    # chamfer is hard-wired, so metric='frechet' would not actually apply.
    if metric == 'chamfer':
        distance_fn = chamfer_distance

    elif metric == 'frechet':
        distance_fn = frechet_distance

    else:
        raise ValueError(f'unknown distance function {metric}')

    num_preds = pred_lines.shape[0]
    num_gts = gt_lines.shape[0]

    # tp and fp
    tp_fp_list = []
    tp = np.zeros((num_preds), dtype=np.float32)
    fp = np.zeros((num_preds), dtype=np.float32)

    # if there is no gt lines in this sample, then all pred lines are false positives
    if num_gts == 0:
        fp[...] = 1
        for thr in thresholds:
            tp_fp_list.append((tp.copy(), fp.copy()))
        return tp_fp_list

    # No predictions: return all-zero tp/fp arrays for every threshold.
    if num_preds == 0:
        for thr in thresholds:
            tp_fp_list.append((tp.copy(), fp.copy()))
        return tp_fp_list

    assert pred_lines.shape[1] == gt_lines.shape[1], \
        "sample points num should be the same"

    # distance matrix: M x N
    matrix = np.zeros((num_preds, num_gts))

    # Batched chamfer replaces a per-pair distance_fn double loop.
    matrix = chamfer_distance_batch(pred_lines, gt_lines)
    # for each det, the min distance with all gts
    matrix_min = matrix.min(axis=1)

    # for each det, which gt is the closest to it
    matrix_argmin = matrix.argmin(axis=1)
    # sort all dets in descending order by scores
    sort_inds = np.argsort(-scores)

    # match under different thresholds
    for thr in thresholds:
        tp = np.zeros((num_preds), dtype=np.float32)
        fp = np.zeros((num_preds), dtype=np.float32)

        # Greedy one-to-one matching in descending confidence order.
        gt_covered = np.zeros(num_gts, dtype=bool)
        for i in sort_inds:
            if matrix_min[i] <= thr:
                matched_gt = matrix_argmin[i]
                if not gt_covered[matched_gt]:
                    gt_covered[matched_gt] = True
                    tp[i] = 1
                else:
                    # Closest GT already claimed by a higher-scoring pred.
                    fp[i] = 1
            else:
                fp[i] = 1

        tp_fp_list.append((tp, fp))

    return tp_fp_list
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/datasets/evaluation/map/distance.py b/projects/mmdet3d_plugin/datasets/evaluation/map/distance.py
new file mode 100644
index 0000000..0152755
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/evaluation/map/distance.py
@@ -0,0 +1,67 @@
+from scipy.spatial import distance
+from numpy.typing import NDArray
+import torch
+
def chamfer_distance(line1: NDArray, line2: NDArray) -> float:
    ''' Calculate chamfer distance between two lines. Make sure the
    lines are interpolated.

    Args:
        line1 (array): coordinates of line1, shape (M, D)
        line2 (array): coordinates of line2, shape (N, D)

    Returns:
        distance (float): chamfer distance
    '''
    # all-pairs euclidean distances between the two point sets
    pairwise = distance.cdist(line1, line2, 'euclidean')

    # directed terms: mean distance from each point to its nearest
    # neighbour on the other line
    forward = pairwise.min(axis=1).mean()
    backward = pairwise.min(axis=0).mean()

    # symmetric chamfer distance is the average of the two directed terms
    return 0.5 * (forward + backward)
+
def frechet_distance(line1: NDArray, line2: NDArray) -> float:
    ''' Calculate frechet distance between two lines. Make sure the
    lines are interpolated.

    Args:
        line1 (array): coordinates of line1
        line2 (array): coordinates of line2

    Returns:
        distance (float): frechet distance

    Raises:
        NotImplementedError: always — this metric is a placeholder and has
            not been implemented yet.
    '''
    raise NotImplementedError
+
def chamfer_distance_batch(pred_lines, gt_lines):
    ''' Calculate chamfer distance between two groups of lines. Make sure the
    lines are interpolated.

    Args:
        pred_lines (array or tensor): shape (m, num_pts, 2 or 3)
        gt_lines (array or tensor): shape (n, num_pts, 2 or 3)

    Returns:
        distance (array): chamfer distance matrix, shape (m, n)
    '''
    num_preds = pred_lines.shape[0]
    num_gts = gt_lines.shape[0]
    _, num_pts, coord_dims = pred_lines.shape

    # BUGFIX/robustness: torch.split on an empty tensor yields no chunks and
    # torch.stack would raise, so short-circuit degenerate inputs with an
    # empty distance matrix instead of crashing.
    if num_preds == 0 or num_gts == 0:
        return torch.zeros((num_preds, num_gts)).numpy()

    if not isinstance(pred_lines, torch.Tensor):
        pred_lines = torch.tensor(pred_lines)
    if not isinstance(gt_lines, torch.Tensor):
        gt_lines = torch.tensor(gt_lines)

    # all-pairs point distances: (num_preds*num_pts, num_gts*num_pts)
    dist_mat = torch.cdist(pred_lines.view(-1, coord_dims),
                           gt_lines.view(-1, coord_dims), p=2)
    # regroup rows per prediction: (num_preds, num_pts, num_gts*num_pts)
    dist_mat = torch.stack(torch.split(dist_mat, num_pts))
    # regroup cols per gt: (num_gts, num_preds, num_pts, num_pts)
    dist_mat = torch.stack(torch.split(dist_mat, num_pts, dim=-1))

    # directed chamfer terms, summed over the points of each line
    dist1 = dist_mat.min(-1)[0].sum(-1)
    dist2 = dist_mat.min(-2)[0].sum(-1)

    # symmetric chamfer, transposed to (num_preds, num_gts); both lines have
    # num_pts points, hence the single 2 * num_pts normalizer
    dist_matrix = (dist1 + dist2).transpose(0, 1) / (2 * num_pts)

    return dist_matrix.numpy()
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/datasets/evaluation/map/vector_eval.py b/projects/mmdet3d_plugin/datasets/evaluation/map/vector_eval.py
new file mode 100644
index 0000000..de7e83b
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/evaluation/map/vector_eval.py
@@ -0,0 +1,279 @@
+import prettytable
+from typing import Dict, List, Optional
+from time import time
+from copy import deepcopy
+from multiprocessing import Pool
+from logging import Logger
+from functools import partial, cached_property
+
+import numpy as np
+from numpy.typing import NDArray
+from shapely.geometry import LineString
+
+import mmcv
+from mmcv import Config
+from mmdet.datasets import build_dataset, build_dataloader
+
+from .AP import instance_match, average_precision
+
+INTERP_NUM = 200 # number of points to interpolate during evaluation
+THRESHOLDS = [0.5, 1.0, 1.5] # AP thresholds
+N_WORKERS = 8 # num workers to parallel
+
class VectorEvaluate(object):
    """Evaluator for vectorized map predictions.

    Ground-truth vectors are collected once from the dataset built from
    `dataset_cfg` and cached; a submission file is then matched against them
    per category and per distance threshold, and AP is reported.

    Args:
        dataset_cfg (Config): dataset cfg for gt
        n_workers (int): num workers to parallelize per-sample matching
    """

    def __init__(self, dataset_cfg: Config, n_workers: int=N_WORKERS) -> None:
        self.dataset = build_dataset(dataset_cfg)
        self.dataloader = build_dataloader(
            self.dataset, samples_per_gpu=1, workers_per_gpu=n_workers, shuffle=False, dist=False)
        classes = self.dataset.MAP_CLASSES
        # category name <-> integer label lookups
        self.cat2id = {cls: i for i, cls in enumerate(classes)}
        self.id2cat = {v: k for k, v in self.cat2id.items()}
        self.n_workers = n_workers
        # use the module-level constant instead of re-hardcoding [0.5, 1.0, 1.5]
        self.thresholds = THRESHOLDS

    @cached_property
    def gts(self) -> Dict[str, Dict[int, List[NDArray]]]:
        """Collect gt vectors for every sample, keyed by sample token (cached)."""
        print('collecting gts...')
        gts = {}
        pbar = mmcv.ProgressBar(len(self.dataloader))
        for data in self.dataloader:
            token = deepcopy(data['img_metas'].data[0][0]['token'])
            gt = deepcopy(data['vectors'].data[0][0])
            gts[token] = gt
            pbar.update()
            del data  # avoid dataloader memory crash

        return gts

    def interp_fixed_num(self,
                         vector: NDArray,
                         num_pts: int) -> NDArray:
        ''' Interpolate a polyline to a fixed number of evenly spaced points.

        Args:
            vector (array): line coordinates, shape (M, 2)
            num_pts (int): number of points to sample along the line

        Returns:
            sampled_points (array): interpolated coordinates, shape (num_pts, 2)
        '''
        line = LineString(vector)
        distances = np.linspace(0, line.length, num_pts)
        sampled_points = np.array([list(line.interpolate(distance).coords)
                                   for distance in distances]).squeeze()

        return sampled_points

    def interp_fixed_dist(self,
                          vector: NDArray,
                          sample_dist: float) -> NDArray:
        ''' Interpolate a line at a fixed interval.

        Args:
            vector (LineString): vector
            sample_dist (float): sample interval

        Returns:
            points (array): interpolated points, shape (N, 2)
        '''
        line = LineString(vector)
        distances = list(np.arange(sample_dist, line.length, sample_dist))
        # make sure to sample at least two points when sample_dist > line.length
        distances = [0,] + distances + [line.length,]

        sampled_points = np.array([list(line.interpolate(distance).coords)
                                   for distance in distances]).squeeze()

        return sampled_points

    def _evaluate_single(self,
                         pred_vectors: List,
                         scores: List,
                         groundtruth: List,
                         thresholds: List,
                         metric: str='chamfer') -> Dict[int, NDArray]:
        ''' Do single-frame matching for one class.

        Args:
            pred_vectors (List): List[vector(ndarray) (different length)]
            scores (List): List[score(float)]
            groundtruth (List): List of gt vectors
            thresholds (List): List of distance thresholds
            metric (str): distance metric passed to ``instance_match``.
                Default: 'chamfer'. (The previous default 'metric' was not a
                valid metric name; callers always pass this explicitly.)

        Returns:
            tp_fp_score_by_thr (Dict): matching results at different thresholds,
                e.g. {0.5: (M, 3), 1.0: (M, 3), 1.5: (M, 3)} where the three
                columns are (tp, fp, score).
        '''
        pred_lines = []

        # interpolate predictions to a common number of points
        for vector in pred_vectors:
            vector = np.array(vector)
            vector_interp = self.interp_fixed_num(vector, INTERP_NUM)
            pred_lines.append(vector_interp)
        if pred_lines:
            pred_lines = np.stack(pred_lines)
        else:
            pred_lines = np.zeros((0, INTERP_NUM, 2))

        # interpolate groundtruth
        gt_lines = []
        for vector in groundtruth:
            vector_interp = self.interp_fixed_num(vector, INTERP_NUM)
            gt_lines.append(vector_interp)
        if gt_lines:
            gt_lines = np.stack(gt_lines)
        else:
            gt_lines = np.zeros((0, INTERP_NUM, 2))

        scores = np.array(scores)
        tp_fp_list = instance_match(pred_lines, scores, gt_lines, thresholds, metric)
        tp_fp_score_by_thr = {}
        for i, thr in enumerate(thresholds):
            tp, fp = tp_fp_list[i]
            # stack per-detection (tp, fp, score) columns for later sorting
            tp_fp_score = np.hstack([tp[:, None], fp[:, None], scores[:, None]])
            tp_fp_score_by_thr[thr] = tp_fp_score

        return tp_fp_score_by_thr

    def evaluate(self,
                 result_path: str,
                 metric: str='chamfer',
                 logger: Optional[Logger]=None) -> Dict[str, float]:
        ''' Do evaluation for a submission file and print evaluation results to `logger` if specified.
        The submission will be aligned by tokens before evaluation. We use multi-worker to speed up.

        Args:
            result_path (str): path to submission file
            metric (str): distance metric. Default: 'chamfer'
            logger (Logger): logger to print evaluation result, Default: None

        Returns:
            new_result_dict (Dict): evaluation results. AP by categories.
        '''
        results = mmcv.load(result_path)
        results = results['results']

        # re-group samples and gt by label
        samples_by_cls = {label: [] for label in self.id2cat.keys()}
        num_gts = {label: 0 for label in self.id2cat.keys()}
        num_preds = {label: 0 for label in self.id2cat.keys()}

        # align by token; samples missing from the submission count as
        # empty predictions (all their gts become misses)
        for token, gt in self.gts.items():
            if token in results.keys():
                pred = results[token]
            else:
                pred = {'vectors': [], 'scores': [], 'labels': []}

            # for every sample
            vectors_by_cls = {label: [] for label in self.id2cat.keys()}
            scores_by_cls = {label: [] for label in self.id2cat.keys()}

            for i in range(len(pred['labels'])):
                # i-th pred line in sample
                label = pred['labels'][i]
                vector = pred['vectors'][i]
                score = pred['scores'][i]

                vectors_by_cls[label].append(vector)
                scores_by_cls[label].append(score)

            for label in self.id2cat.keys():
                new_sample = (vectors_by_cls[label], scores_by_cls[label], gt[label])
                num_gts[label] += len(gt[label])
                num_preds[label] += len(scores_by_cls[label])
                samples_by_cls[label].append(new_sample)

        result_dict = {}

        print(f'\nevaluating {len(self.id2cat)} categories...')
        start = time()
        if self.n_workers > 0:
            pool = Pool(self.n_workers)

        sum_mAP = 0
        pbar = mmcv.ProgressBar(len(self.id2cat))
        for label in self.id2cat.keys():
            samples = samples_by_cls[label]  # List[(pred_lines, scores, gts)]
            result_dict[self.id2cat[label]] = {
                'num_gts': num_gts[label],
                'num_preds': num_preds[label]
            }
            sum_AP = 0

            fn = partial(self._evaluate_single, thresholds=self.thresholds, metric=metric)
            # only fan out to worker processes for reasonably large workloads;
            # for small sample counts the fork overhead dominates
            if self.n_workers > 0 and len(samples) > 81:
                tpfp_score_list = pool.starmap(fn, samples)
            else:
                tpfp_score_list = []
                for sample in samples:
                    tpfp_score_list.append(fn(*sample))

            for thr in self.thresholds:
                tp_fp_score = [i[thr] for i in tpfp_score_list]
                tp_fp_score = np.vstack(tp_fp_score)  # (num_dets, 3)
                # rank all detections of this class by descending score
                sort_inds = np.argsort(-tp_fp_score[:, -1])

                tp = tp_fp_score[sort_inds, 0]  # (num_dets,)
                fp = tp_fp_score[sort_inds, 1]  # (num_dets,)
                tp = np.cumsum(tp, axis=0)
                fp = np.cumsum(fp, axis=0)
                eps = np.finfo(np.float32).eps
                recalls = tp / np.maximum(num_gts[label], eps)
                precisions = tp / np.maximum((tp + fp), eps)

                AP = average_precision(recalls, precisions, 'area')
                sum_AP += AP
                result_dict[self.id2cat[label]].update({f'AP@{thr}': AP})

            pbar.update()

            # per-class AP averaged over thresholds
            AP = sum_AP / len(self.thresholds)
            sum_mAP += AP

            result_dict[self.id2cat[label]].update({f'AP': AP})

        if self.n_workers > 0:
            pool.close()

        mAP = sum_mAP / len(self.id2cat.keys())
        result_dict.update({'mAP': mAP})

        print(f"finished in {time() - start:.2f}s")

        # print results
        table = prettytable.PrettyTable(['category', 'num_preds', 'num_gts'] +
                                        [f'AP@{thr}' for thr in self.thresholds] + ['AP'])
        for label in self.id2cat.keys():
            table.add_row([
                self.id2cat[label],
                result_dict[self.id2cat[label]]['num_preds'],
                result_dict[self.id2cat[label]]['num_gts'],
                *[round(result_dict[self.id2cat[label]][f'AP@{thr}'], 4) for thr in self.thresholds],
                round(result_dict[self.id2cat[label]]['AP'], 4),
            ])

        from mmcv.utils import print_log
        print_log('\n' + str(table), logger=logger)
        mAP_normal = 0
        for label in self.id2cat.keys():
            for thr in self.thresholds:
                mAP_normal += result_dict[self.id2cat[label]][f'AP@{thr}']
        # BUGFIX: the divisor was hard-coded to 9 (3 classes x 3 thresholds);
        # derive it from the actual number of (class, threshold) pairs so the
        # evaluator generalizes to other class sets. Value is unchanged for
        # the standard 3-class / 3-threshold configuration.
        mAP_normal = mAP_normal / (len(self.id2cat) * len(self.thresholds))

        print_log(f'mAP_normal = {mAP_normal:.4f}\n', logger=logger)

        new_result_dict = {}
        for name in self.cat2id:
            new_result_dict[name] = result_dict[name]['AP']
        new_result_dict['mAP_normal'] = mAP_normal
        return new_result_dict
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/datasets/evaluation/motion/motion_eval_uniad.py b/projects/mmdet3d_plugin/datasets/evaluation/motion/motion_eval_uniad.py
new file mode 100644
index 0000000..c1f27db
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/evaluation/motion/motion_eval_uniad.py
@@ -0,0 +1,417 @@
+# nuScenes dev-kit.
+# Code written by Holger Caesar & Oscar Beijbom, 2018.
+
+import argparse
+import json
+import os
+import random
+import time
+import tqdm
+from typing import Tuple, Dict, Any
+
+import numpy as np
+
+from nuscenes import NuScenes
+from nuscenes.eval.common.config import config_factory
+from nuscenes.eval.common.data_classes import EvalBoxes
+from nuscenes.eval.common.loaders import add_center_dist, filter_eval_boxes
+from nuscenes.eval.detection.algo import accumulate, calc_ap, calc_tp
+from nuscenes.eval.detection.constants import DETECTION_NAMES, ATTRIBUTE_NAMES, TP_METRICS
+from nuscenes.eval.detection.data_classes import DetectionConfig, DetectionMetrics, DetectionBox, \
+ DetectionMetricDataList, DetectionMetricData
+from nuscenes.eval.detection.render import summary_plot, class_pr_curve, class_tp_curve, dist_pr_curve, visualize_sample
+from nuscenes.prediction import PredictHelper, convert_local_coords_to_global
+from nuscenes.utils.splits import create_splits_scenes
+from nuscenes.eval.detection.utils import category_to_detection_name
+from nuscenes.eval.common.utils import quaternion_yaw, Quaternion
+from nuscenes.eval.common.utils import center_distance, scale_iou, yaw_diff, velocity_l2, attr_acc, cummean
+
+from .motion_utils import MotionBox, load_prediction, load_gt, accumulate
+MOTION_TP_METRICS = ['min_ade_err', 'min_fde_err', 'miss_rate_err', 'top1_fde_err', 'brier_min_fde_err']
+
+
class MotionEval:
    """
    This is the official nuScenes detection evaluation code.
    Results are written to the provided output_dir.

    nuScenes uses the following detection metrics:
    - Mean Average Precision (mAP): Uses center-distance as matching criterion; averaged over distance thresholds.
    - True Positive (TP) metrics: Average of translation, velocity, scale, orientation and attribute errors.
    - nuScenes Detection Score (NDS): The weighted sum of the above.

    Here is an overview of the functions in this method:
    - init: Loads GT annotations and predictions stored in JSON format and filters the boxes.
    - run: Performs evaluation and dumps the metric data to disk.
    - render: Renders various plots and dumps to disk.

    We assume that:
    - Every sample_token is given in the results, although there may be not predictions for that sample.

    Please see https://www.nuscenes.org/object-detection for more details.

    NOTE(review): this class is adapted for motion (trajectory) evaluation —
    ``evaluate`` scores the MOTION_TP_METRICS trajectory errors plus EPA, not
    the standard detection mAP/NDS described above.
    """
    def __init__(self,
                 nusc: NuScenes,
                 config: DetectionConfig,
                 result_path: str,
                 eval_set: str,
                 output_dir: str = None,
                 verbose: bool = True,
                 seconds: int = 12):
        """
        Initialize a DetectionEval object.
        :param nusc: A NuScenes object.
        :param config: A DetectionConfig object.
        :param result_path: Path of the nuScenes JSON result file.
        :param eval_set: The dataset split to evaluate on, e.g. train, val or test.
        :param output_dir: Folder to save plots and results to.
        :param verbose: Whether to print to stdout.
        :param seconds: horizon forwarded to load_gt — presumably seconds of
            future trajectory to load; confirm against load_gt's signature.
        """
        self.nusc = nusc
        self.result_path = result_path
        self.eval_set = eval_set
        self.output_dir = output_dir
        self.verbose = verbose
        self.cfg = config

        # Check result file exists.
        # assert os.path.exists(result_path), 'Error: The result file does not exist!'

        # Make dirs.
        self.plot_dir = os.path.join(self.output_dir, 'plots')
        if not os.path.isdir(self.output_dir):
            os.makedirs(self.output_dir)
        if not os.path.isdir(self.plot_dir):
            os.makedirs(self.plot_dir)

        # Load data.
        if verbose:
            print('Initializing nuScenes detection evaluation')
        self.pred_boxes, self.meta = load_prediction(self.result_path, self.cfg.max_boxes_per_sample, MotionBox,
                                                     verbose=verbose)
        self.gt_boxes = load_gt(self.nusc, self.eval_set, MotionBox, verbose=verbose, seconds=seconds)

        assert set(self.pred_boxes.sample_tokens) == set(self.gt_boxes.sample_tokens), \
            "Samples in split doesn't match samples in predictions."

        # Add center distances.
        self.pred_boxes = add_center_dist(nusc, self.pred_boxes)
        self.gt_boxes = add_center_dist(nusc, self.gt_boxes)

        # Filter boxes (distance, points per box, etc.).
        if verbose:
            print('Filtering predictions')
        self.pred_boxes = filter_eval_boxes(nusc, self.pred_boxes, self.cfg.class_range, verbose=verbose)
        if verbose:
            print('Filtering ground truth annotations')
        self.gt_boxes = filter_eval_boxes(nusc, self.gt_boxes, self.cfg.class_range, verbose=verbose)

        self.sample_tokens = self.gt_boxes.sample_tokens

    def evaluate(self) -> Tuple[Dict[str, Any], DetectionMetricDataList]:
        """
        Performs the actual evaluation.
        :return: A tuple of (metrics dict keyed by '<class>_<metric>' strings,
            the raw metric data list).
        """
        # start_time is currently unused within this method
        start_time = time.time()
        # NOTE(review): evaluation is restricted here to two merged classes and
        # a single 2.0 m matching threshold; this mutates self.cfg in place, so
        # any later use of self.cfg sees the overridden values.
        self.cfg.class_names = ['car', 'pedestrian']
        self.cfg.dist_ths = [2.0]

        # -----------------------------------
        # Step 1: Accumulate metric data for all classes and distance thresholds.
        # -----------------------------------
        if self.verbose:
            print('Accumulating metric data...')
        metric_data_list = DetectionMetricDataList()
        metrics = {}
        for class_name in self.cfg.class_names:
            for dist_th in self.cfg.dist_ths:
                # NOTE(review): only EPA_ is stored; the second return value
                # EPA is discarded — confirm against this accumulate's contract.
                md, EPA, EPA_ = accumulate(self.gt_boxes, self.pred_boxes, class_name, self.cfg.dist_fcn_callable, dist_th)
                metric_data_list.set(class_name, dist_th, md)
                metrics[f'{class_name}_EPA'] = EPA_

        # -----------------------------------
        # Step 2: Calculate metrics from the data.
        # -----------------------------------
        if self.verbose:
            print('Calculating metrics...')
        for class_name in self.cfg.class_names:
            # Compute TP metrics.
            for metric_name in MOTION_TP_METRICS:
                # assumes cfg.dist_th_tp (2.0 in the CVPR 2019 config) is among
                # the accumulated thresholds above — otherwise this lookup fails
                metric_data = metric_data_list[(class_name, self.cfg.dist_th_tp)]
                tp = calc_tp(metric_data, self.cfg.min_recall, metric_name)
                metrics[f'{class_name}_{metric_name}'] = tp

        return metrics, metric_data_list

    def render(self, metrics: DetectionMetrics, md_list: DetectionMetricDataList) -> None:
        """
        Renders various PR and TP curves.
        :param metrics: DetectionMetrics instance.
        :param md_list: DetectionMetricDataList instance.
        """
        if self.verbose:
            print('Rendering PR and TP curves')

        def savepath(name):
            # all plots are written as PDFs under the class plot directory
            return os.path.join(self.plot_dir, name + '.pdf')

        summary_plot(md_list, metrics, min_precision=self.cfg.min_precision, min_recall=self.cfg.min_recall,
                     dist_th_tp=self.cfg.dist_th_tp, savepath=savepath('summary'))

        for detection_name in self.cfg.class_names:
            class_pr_curve(md_list, metrics, detection_name, self.cfg.min_precision, self.cfg.min_recall,
                           savepath=savepath(detection_name + '_pr'))

            class_tp_curve(md_list, metrics, detection_name, self.cfg.min_recall, self.cfg.dist_th_tp,
                           savepath=savepath(detection_name + '_tp'))

        for dist_th in self.cfg.dist_ths:
            dist_pr_curve(md_list, metrics, dist_th, self.cfg.min_precision, self.cfg.min_recall,
                          savepath=savepath('dist_pr_' + str(dist_th)))

    def main(self,
             plot_examples: int = 0,
             render_curves: bool = True) -> Dict[str, Any]:
        """
        Main function that loads the evaluation code, visualizes samples, runs the evaluation and renders stat plots.
        :param plot_examples: How many example visualizations to write to disk.
        :param render_curves: Whether to render PR and TP curves to disk.
            NOTE(review): this parameter is accepted but self.render is never
            invoked in this implementation.
        :return: A dict that stores the high-level metrics and meta data.
        """
        if plot_examples > 0:
            # Select a random but fixed subset to plot.
            random.seed(42)
            sample_tokens = list(self.sample_tokens)
            random.shuffle(sample_tokens)
            sample_tokens = sample_tokens[:plot_examples]

            # Visualize samples.
            example_dir = os.path.join(self.output_dir, 'examples')
            if not os.path.isdir(example_dir):
                os.mkdir(example_dir)
            for sample_token in sample_tokens:
                visualize_sample(self.nusc,
                                 sample_token,
                                 self.gt_boxes if self.eval_set != 'test' else EvalBoxes(),
                                 # Don't render test GT.
                                 self.pred_boxes,
                                 eval_range=max(self.cfg.class_range.values()),
                                 savepath=os.path.join(example_dir, '{}.png'.format(sample_token)))

        # Run evaluation.
        metrics, metric_data_list = self.evaluate()

        return metrics
+
class NuScenesEval(MotionEval):
    """Backward-compatible alias of :class:`MotionEval`; adds no behavior."""
+
+
class OccludedMotionEval(MotionEval):
    """Evaluates motion prediction specifically for occluded objects (num_lidar_pts == 0).

    Loads GT filtered to only boxes with zero LiDAR points and applies class +
    distance filtering without the standard num_pts >= 1 gate.
    """

    def __init__(self,
                 nusc: NuScenes,
                 config: DetectionConfig,
                 result_path: str,
                 eval_set: str,
                 output_dir: str = None,
                 verbose: bool = True,
                 seconds: int = 12):
        """
        :param nusc: A NuScenes object.
        :param config: A DetectionConfig object.
        :param result_path: Path of the nuScenes JSON result file.
        :param eval_set: The dataset split to evaluate on, e.g. train, val or test.
        :param output_dir: Folder to save plots and results to.
        :param verbose: Whether to print to stdout.
        :param seconds: horizon forwarded to load_gt — presumably seconds of
            future trajectory to load; confirm against load_gt's signature.
        """
        # NOTE(review): deliberately does not call MotionEval.__init__ —
        # it re-implements loading so GT can be restricted to occluded boxes.
        self.nusc = nusc
        self.result_path = result_path
        self.eval_set = eval_set
        self.output_dir = output_dir
        self.verbose = verbose
        self.cfg = config

        # Make dirs.
        self.plot_dir = os.path.join(self.output_dir, 'plots')
        if not os.path.isdir(self.output_dir):
            os.makedirs(self.output_dir)
        if not os.path.isdir(self.plot_dir):
            os.makedirs(self.plot_dir)

        # Load data.
        if verbose:
            print('Initializing occluded motion evaluation')
        self.pred_boxes, self.meta = load_prediction(
            self.result_path, self.cfg.max_boxes_per_sample, MotionBox, verbose=verbose
        )
        # Load GT restricted to boxes with zero lidar points.
        self.gt_boxes = load_gt(
            self.nusc, self.eval_set, MotionBox, verbose=verbose,
            seconds=seconds, occluded_only=True
        )

        assert set(self.pred_boxes.sample_tokens) == set(self.gt_boxes.sample_tokens), \
            "Samples in split doesn't match samples in predictions."

        # Add center distances.
        self.pred_boxes = add_center_dist(nusc, self.pred_boxes)
        self.gt_boxes = add_center_dist(nusc, self.gt_boxes)

        # Filter predictions normally (class + distance + score).
        if verbose:
            print('Filtering predictions')
        self.pred_boxes = filter_eval_boxes(nusc, self.pred_boxes, self.cfg.class_range, verbose=verbose)

        # For GT: apply only class + distance filter; skip num_pts check because
        # every occluded box has num_pts == 0 and would be removed by the standard filter.
        if verbose:
            print('Filtering occluded ground truth (class + distance only)')
        self.gt_boxes = self._filter_occluded_gt(self.gt_boxes)

        self.sample_tokens = self.gt_boxes.sample_tokens

    def _filter_occluded_gt(self, gt_boxes):
        """Return GT boxes restricted to known classes within their distance range."""
        from collections import Counter
        from nuscenes.eval.common.data_classes import EvalBoxes as _EvalBoxes
        filtered = _EvalBoxes()
        for sample_token in gt_boxes.sample_tokens:
            # keep a box only if its class is evaluated and it lies inside the
            # per-class range limit measured from the ego vehicle
            boxes = [
                box for box in gt_boxes[sample_token]
                if box.detection_name in self.cfg.class_range
                and box.ego_dist < self.cfg.class_range[box.detection_name]
            ]
            filtered.add_boxes(sample_token, boxes)
        # summary line for sanity-checking how much GT survives the filter
        total = sum(len(filtered[t]) for t in filtered.sample_tokens)
        class_counts = Counter(
            box.detection_name
            for t in filtered.sample_tokens
            for box in filtered[t]
        )
        print(f'[Occluded Motion] GT occluded boxes: {total} | {dict(class_counts)}')
        return filtered
+
+
class AllMotionEval(OccludedMotionEval):
    """Evaluates motion prediction on all objects (visible + occluded, num_pts >= 0).

    Identical to OccludedMotionEval but loads GT without the occluded_only
    restriction, so predictions are scored against every annotated object.
    """

    def __init__(self,
                 nusc: NuScenes,
                 config: DetectionConfig,
                 result_path: str,
                 eval_set: str,
                 output_dir: str = None,
                 verbose: bool = True,
                 seconds: int = 12):
        """
        :param nusc: A NuScenes object.
        :param config: A DetectionConfig object.
        :param result_path: Path of the nuScenes JSON result file.
        :param eval_set: The dataset split to evaluate on, e.g. train, val or test.
        :param output_dir: Folder to save plots and results to.
        :param verbose: Whether to print to stdout.
        :param seconds: horizon forwarded to load_gt — presumably seconds of
            future trajectory to load; confirm against load_gt's signature.
        """
        # NOTE(review): deliberately does not call the parent __init__ — the
        # loading sequence is re-implemented without occluded_only=True.
        self.nusc = nusc
        self.result_path = result_path
        self.eval_set = eval_set
        self.output_dir = output_dir
        self.verbose = verbose
        self.cfg = config

        self.plot_dir = os.path.join(self.output_dir, 'plots')
        if not os.path.isdir(self.output_dir):
            os.makedirs(self.output_dir)
        if not os.path.isdir(self.plot_dir):
            os.makedirs(self.plot_dir)

        if verbose:
            print('Initializing all-objects motion evaluation')
        self.pred_boxes, self.meta = load_prediction(
            self.result_path, self.cfg.max_boxes_per_sample, MotionBox, verbose=verbose
        )
        # Load GT for all objects (visible + occluded).
        self.gt_boxes = load_gt(
            self.nusc, self.eval_set, MotionBox, verbose=verbose,
            seconds=seconds
        )

        assert set(self.pred_boxes.sample_tokens) == set(self.gt_boxes.sample_tokens), \
            "Samples in split doesn't match samples in predictions."

        self.pred_boxes = add_center_dist(nusc, self.pred_boxes)
        self.gt_boxes = add_center_dist(nusc, self.gt_boxes)

        if verbose:
            print('Filtering predictions')
        self.pred_boxes = filter_eval_boxes(nusc, self.pred_boxes, self.cfg.class_range, verbose=verbose)

        # class + distance filter only; the standard num_pts >= 1 gate is
        # skipped so occluded boxes are retained alongside visible ones
        if verbose:
            print('Filtering all ground truth (class + distance only)')
        self.gt_boxes = self._filter_all_gt(self.gt_boxes)

        self.sample_tokens = self.gt_boxes.sample_tokens

    def _filter_all_gt(self, gt_boxes):
        """Return GT boxes for all objects within their distance range (no num_pts filter)."""
        from collections import Counter
        from nuscenes.eval.common.data_classes import EvalBoxes as _EvalBoxes
        filtered = _EvalBoxes()
        for sample_token in gt_boxes.sample_tokens:
            # keep a box only if its class is evaluated and it lies inside the
            # per-class range limit measured from the ego vehicle
            boxes = [
                box for box in gt_boxes[sample_token]
                if box.detection_name in self.cfg.class_range
                and box.ego_dist < self.cfg.class_range[box.detection_name]
            ]
            filtered.add_boxes(sample_token, boxes)
        # summary line for sanity-checking how much GT survives the filter
        total = sum(len(filtered[t]) for t in filtered.sample_tokens)
        class_counts = Counter(
            box.detection_name
            for t in filtered.sample_tokens
            for box in filtered[t]
        )
        print(f'[All Motion] GT all boxes: {total} | {dict(class_counts)}')
        return filtered
+
+
if __name__ == "__main__":

    # Settings.
    parser = argparse.ArgumentParser(description='Evaluate nuScenes detection results.',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('result_path', type=str, help='The submission as a JSON file.')
    parser.add_argument('--output_dir', type=str, default='~/nuscenes-metrics',
                        help='Folder to store result metrics, graphs and example visualizations.')
    parser.add_argument('--eval_set', type=str, default='val',
                        help='Which dataset split to evaluate on, train, val or test.')
    parser.add_argument('--dataroot', type=str, default='/data/sets/nuscenes',
                        help='Default nuScenes data directory.')
    parser.add_argument('--version', type=str, default='v1.0-trainval',
                        help='Which version of the nuScenes dataset to evaluate on, e.g. v1.0-trainval.')
    parser.add_argument('--config_path', type=str, default='',
                        help='Path to the configuration file.'
                             'If no path given, the CVPR 2019 configuration will be used.')
    parser.add_argument('--plot_examples', type=int, default=10,
                        help='How many example visualizations to write to disk.')
    parser.add_argument('--render_curves', type=int, default=1,
                        help='Whether to render PR and TP curves to disk.')
    parser.add_argument('--verbose', type=int, default=1,
                        help='Whether to print to stdout.')
    args = parser.parse_args()

    result_path_ = os.path.expanduser(args.result_path)
    output_dir_ = os.path.expanduser(args.output_dir)
    eval_set_ = args.eval_set
    dataroot_ = args.dataroot
    version_ = args.version
    config_path = args.config_path
    plot_examples_ = args.plot_examples
    render_curves_ = bool(args.render_curves)
    verbose_ = bool(args.verbose)

    # Fall back to the standard CVPR 2019 detection config when none is given.
    if config_path == '':
        cfg_ = config_factory('detection_cvpr_2019')
    else:
        with open(config_path, 'r') as _f:
            cfg_ = DetectionConfig.deserialize(json.load(_f))

    nusc_ = NuScenes(version=version_, verbose=verbose_, dataroot=dataroot_)
    # BUGFIX: `DetectionEval` is neither defined nor imported in this module
    # and raised a NameError; the evaluator implemented here is MotionEval.
    nusc_eval = MotionEval(nusc_, config=cfg_, result_path=result_path_, eval_set=eval_set_,
                           output_dir=output_dir_, verbose=verbose_)
    nusc_eval.main(plot_examples=plot_examples_, render_curves=render_curves_)
diff --git a/projects/mmdet3d_plugin/datasets/evaluation/motion/motion_utils.py b/projects/mmdet3d_plugin/datasets/evaluation/motion/motion_utils.py
new file mode 100644
index 0000000..2a522b4
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/evaluation/motion/motion_utils.py
@@ -0,0 +1,680 @@
+# nuScenes dev-kit.
+# Code written by Holger Caesar & Oscar Beijbom, 2018.
+
+import argparse
+import json
+import os
+import random
+import time
+import tqdm
+from typing import Tuple, Dict, Any, Callable
+
+import numpy as np
+
+from nuscenes import NuScenes
+from nuscenes.eval.common.config import config_factory
+from nuscenes.eval.common.data_classes import EvalBoxes
+from nuscenes.eval.common.loaders import add_center_dist, filter_eval_boxes
+from nuscenes.eval.detection.algo import calc_ap, calc_tp
+from nuscenes.eval.detection.constants import DETECTION_NAMES, ATTRIBUTE_NAMES, TP_METRICS
+from nuscenes.eval.detection.data_classes import DetectionConfig, DetectionMetrics, DetectionBox, \
+ DetectionMetricDataList, DetectionMetricData
+from nuscenes.eval.detection.render import summary_plot, class_pr_curve, class_tp_curve, dist_pr_curve, visualize_sample
+from nuscenes.prediction import PredictHelper, convert_local_coords_to_global
+from nuscenes.utils.splits import create_splits_scenes
+from nuscenes.eval.detection.utils import category_to_detection_name
+from nuscenes.eval.common.utils import quaternion_yaw, Quaternion
+from nuscenes.eval.common.utils import center_distance, scale_iou, yaw_diff, velocity_l2, attr_acc, cummean
+
+
# Coarse motion categories: every vehicle-like detection class collapses to
# 'car', pedestrians stay as-is, and static obstacles map to 'barrier'.
_MOTION_CLASS_GROUPS = {
    'car': ('car', 'truck', 'construction_vehicle', 'bus', 'trailer',
            'motorcycle', 'bicycle'),
    'pedestrian': ('pedestrian',),
    'barrier': ('traffic_cone', 'barrier'),
}

motion_name_mapping = {
    fine_name: coarse_name
    for coarse_name, fine_names in _MOTION_CLASS_GROUPS.items()
    for fine_name in fine_names
}
+
+
class MotionBox(DetectionBox):
    """ Data class used during motion evaluation. Can be a prediction or ground truth.

    Extends DetectionBox with a future trajectory (`traj`) and, optionally,
    per-mode trajectory confidences (`traj_score`).
    """

    def __init__(self,
                 sample_token: str = "",
                 translation: Tuple[float, float, float] = (0, 0, 0),
                 size: Tuple[float, float, float] = (0, 0, 0),
                 rotation: Tuple[float, float, float, float] = (0, 0, 0, 0),
                 velocity: Tuple[float, float] = (0, 0),
                 ego_translation: Tuple[float, float, float] = (0, 0, 0),  # Translation to ego vehicle in meters.
                 num_pts: int = -1,  # Nbr. LIDAR or RADAR inside the box. Only for gt boxes.
                 detection_name: str = 'car',  # The class name used in the detection challenge.
                 detection_score: float = -1.0,  # GT samples do not have a score.
                 attribute_name: str = '',  # Box attribute. Each box can have at most 1 attribute.
                 traj=None,  # Future trajectory waypoints; array-like of (x, y) steps.
                 traj_score=None):  # Optional per-mode confidences for multi-modal trajectories.

        super().__init__(sample_token, translation, size, rotation, velocity, ego_translation, num_pts)

        assert detection_name is not None, 'Error: detection_name cannot be empty!'
        assert detection_name in DETECTION_NAMES, 'Error: Unknown detection_name %s' % detection_name

        assert attribute_name in ATTRIBUTE_NAMES or attribute_name == '', \
            'Error: Unknown attribute_name %s' % attribute_name

        # isinstance (rather than an exact type check) also accepts float
        # subclasses such as np.float64, which scores commonly arrive as.
        assert isinstance(detection_score, float), 'Error: detection_score must be a float!'
        assert not np.any(np.isnan(detection_score)), 'Error: detection_score may not be NaN!'

        # Assign.
        self.detection_name = detection_name
        self.detection_score = detection_score
        self.attribute_name = attribute_name
        self.traj = traj
        self.traj_score = traj_score

    def __eq__(self, other):
        return (self.sample_token == other.sample_token and
                self.translation == other.translation and
                self.size == other.size and
                self.rotation == other.rotation and
                self.velocity == other.velocity and
                self.ego_translation == other.ego_translation and
                self.num_pts == other.num_pts and
                self.detection_name == other.detection_name and
                self.detection_score == other.detection_score and
                self.attribute_name == other.attribute_name and
                np.all(self.traj == other.traj))

    def serialize(self) -> dict:
        """ Serialize instance into json-friendly format. """
        return {
            'sample_token': self.sample_token,
            'translation': self.translation,
            'size': self.size,
            'rotation': self.rotation,
            'velocity': self.velocity,
            'ego_translation': self.ego_translation,
            'num_pts': self.num_pts,
            'detection_name': self.detection_name,
            'detection_score': self.detection_score,
            'attribute_name': self.attribute_name,
            'traj': self.traj,
            'traj_score': self.traj_score,
        }

    @classmethod
    def deserialize(cls, content: dict):
        """ Initialize from serialized content.

        Accepts both the 'traj'/'traj_score' keys written by `serialize` and the
        'trajs'/'trajs_score' keys used by external result files — previously
        only the latter were read, so serialize/deserialize did not round-trip.
        """
        return cls(sample_token=content['sample_token'],
                   translation=tuple(content['translation']),
                   size=tuple(content['size']),
                   rotation=tuple(content['rotation']),
                   velocity=tuple(content['velocity']),
                   ego_translation=(0.0, 0.0, 0.0) if 'ego_translation' not in content
                   else tuple(content['ego_translation']),
                   num_pts=-1 if 'num_pts' not in content else int(content['num_pts']),
                   detection_name=content['detection_name'],
                   detection_score=-1.0 if 'detection_score' not in content else float(content['detection_score']),
                   attribute_name=content['attribute_name'],
                   traj=content.get('traj', content.get('trajs')),
                   traj_score=content.get('traj_score', content.get('trajs_score')))
+
+
def load_prediction(result_path: str, max_boxes_per_sample: int, box_cls, verbose: bool = False) \
        -> Tuple[EvalBoxes, Dict]:
    """
    Loads object predictions from file or from an already-deserialized result dict.
    :param result_path: Path to the .json result file provided by the user, or the
        result dict itself (this codebase passes the dict in memory).
    :param max_boxes_per_sample: Maximum number of boxes allowed per sample.
    :param box_cls: Type of box to load, e.g. MotionBox or TrackingBox.
    :param verbose: Whether to print messages to stdout.
    :return: The deserialized results and meta data.
    """

    # Accept either a path to a JSON result file or an in-memory result dict.
    if isinstance(result_path, str):
        with open(result_path) as f:
            data = json.load(f)
    else:
        data = result_path
    assert 'results' in data, 'Error: No field `results` in result file. Please note that the result format changed.' \
                              'See https://www.nuscenes.org/object-detection for more information.'

    # Collapse fine-grained detection classes into the coarse motion classes.
    for key in data['results'].keys():
        for record in data['results'][key]:
            cls_name = record['detection_name']
            if cls_name in motion_name_mapping:
                record['detection_name'] = motion_name_mapping[cls_name]

    # Deserialize results and get meta data.
    all_results = EvalBoxes.deserialize(data['results'], box_cls)
    meta = data['meta']
    if verbose:
        print("Loaded results from {}. Found detections for {} samples."
              .format(result_path, len(all_results.sample_tokens)))

    # Check that each sample has no more than x predicted boxes.
    for sample_token in all_results.sample_tokens:
        assert len(all_results.boxes[sample_token]) <= max_boxes_per_sample, \
            "Error: Only <= %d boxes per sample allowed!" % max_boxes_per_sample

    return all_results, meta
+
+
def load_gt(nusc: NuScenes, eval_split: str, box_cls, verbose: bool = False, seconds: int = 12, occluded_only: bool = False) -> EvalBoxes:
    """
    Loads ground truth boxes from DB.
    :param nusc: A NuScenes instance.
    :param eval_split: The evaluation split for which we load GT boxes.
    :param box_cls: Type of box to load, e.g. MotionBox or TrackingBox.
    :param verbose: Whether to print messages to stdout.
    :param seconds: Length of the future horizon (in seconds) for GT trajectories.
    :param occluded_only: If True, keep only boxes with zero lidar/radar points
        (i.e. fully occluded objects). Only applies when loading MotionBox GT.
    :return: The GT boxes.
    """
    # Import locally so a missing motmetrics package does not break motion-only use.
    # Without this, the `box_cls == TrackingBox` comparison below raised NameError
    # for any non-Motion box class, since TrackingBox was never imported.
    try:
        from nuscenes.eval.tracking.data_classes import TrackingBox
    except ImportError:
        TrackingBox = None

    predict_helper = PredictHelper(nusc)
    # Init.
    if box_cls == MotionBox:
        attribute_map = {a['token']: a['name'] for a in nusc.attribute}

    if verbose:
        print('Loading annotations for {} split from nuScenes version: {}'.format(eval_split, nusc.version))
    # Read out all sample_tokens in DB.
    sample_tokens_all = [s['token'] for s in nusc.sample]
    assert len(sample_tokens_all) > 0, "Error: Database has no samples!"

    # Only keep samples from this split.
    splits = create_splits_scenes()

    # Check compatibility of split with nusc_version.
    version = nusc.version
    if eval_split in {'train', 'val', 'train_detect', 'train_track'}:
        assert version.endswith('trainval'), \
            'Error: Requested split {} which is not compatible with NuScenes version {}'.format(eval_split, version)
    elif eval_split in {'mini_train', 'mini_val'}:
        assert version.endswith('mini'), \
            'Error: Requested split {} which is not compatible with NuScenes version {}'.format(eval_split, version)
    elif eval_split == 'test':
        assert version.endswith('test'), \
            'Error: Requested split {} which is not compatible with NuScenes version {}'.format(eval_split, version)
    else:
        raise ValueError('Error: Requested split {} which this function cannot map to the correct NuScenes version.'
                         .format(eval_split))

    if eval_split == 'test':
        # Check that you aren't trying to cheat :).
        assert len(nusc.sample_annotation) > 0, \
            'Error: You are trying to evaluate on the test set but you do not have the annotations!'

    sample_tokens = []
    for sample_token in sample_tokens_all:
        scene_token = nusc.get('sample', sample_token)['scene_token']
        scene_record = nusc.get('scene', scene_token)
        if scene_record['name'] in splits[eval_split]:
            sample_tokens.append(sample_token)

    all_annotations = EvalBoxes()

    # Load annotations and filter predictions and annotations.
    tracking_id_set = set()
    for sample_token in tqdm.tqdm(sample_tokens, leave=verbose):

        sample = nusc.get('sample', sample_token)
        sample_annotation_tokens = sample['anns']

        sample_boxes = []
        for sample_annotation_token in sample_annotation_tokens:

            sample_annotation = nusc.get('sample_annotation', sample_annotation_token)
            if box_cls == MotionBox:
                # Get label name in detection task and filter unused labels.
                detection_name = category_to_detection_name(sample_annotation['category_name'])
                # Collapse fine-grained classes into coarse motion classes.
                if detection_name in motion_name_mapping:
                    detection_name = motion_name_mapping[detection_name]

                if detection_name is None:
                    continue

                # Get attribute_name.
                attr_tokens = sample_annotation['attribute_tokens']
                attr_count = len(attr_tokens)
                if attr_count == 0:
                    attribute_name = ''
                elif attr_count == 1:
                    attribute_name = attribute_map[attr_tokens[0]]
                else:
                    raise Exception('Error: GT annotations must not have more than one attribute!')

                # Fetch the agent's future trajectory in its own frame ...
                instance_token = nusc.get('sample_annotation', sample_annotation['token'])['instance_token']
                fut_traj_local = predict_helper.get_future_for_agent(
                    instance_token,
                    sample_token,
                    seconds=seconds,
                    in_agent_frame=True
                )
                # ... and rotate/translate it into the lidar (scene-centric) frame
                # using the box pose from the LIDAR_TOP sample data.
                if fut_traj_local.shape[0] > 0:
                    _, boxes, _ = nusc.get_sample_data(sample['data']['LIDAR_TOP'], selected_anntokens=[sample_annotation['token']])
                    box = boxes[0]
                    trans = box.center
                    rot = Quaternion(matrix=box.rotation_matrix)
                    fut_traj_scene_centric = convert_local_coords_to_global(fut_traj_local, trans, rot)
                else:
                    fut_traj_scene_centric = np.zeros((0,))

                num_pts = sample_annotation['num_lidar_pts'] + sample_annotation['num_radar_pts']
                # In occluded-only mode, keep boxes with no sensor returns.
                if occluded_only and num_pts > 0:
                    continue
                sample_boxes.append(
                    box_cls(
                        sample_token=sample_token,
                        translation=sample_annotation['translation'],
                        size=sample_annotation['size'],
                        rotation=sample_annotation['rotation'],
                        velocity=nusc.box_velocity(sample_annotation['token'])[:2],
                        num_pts=num_pts,
                        detection_name=detection_name,
                        detection_score=-1.0,  # GT samples do not have a score.
                        attribute_name=attribute_name,
                        traj=fut_traj_scene_centric
                    )
                )
            elif box_cls == TrackingBox:
                # Use nuScenes token as tracking id.
                tracking_id = sample_annotation['instance_token']
                tracking_id_set.add(tracking_id)

                # Get label name in detection task and filter unused labels.
                # Import locally to avoid errors when motmetrics package is not installed.
                from nuscenes.eval.tracking.utils import category_to_tracking_name
                tracking_name = category_to_tracking_name(sample_annotation['category_name'])
                if tracking_name is None:
                    continue

                sample_boxes.append(
                    box_cls(
                        sample_token=sample_token,
                        translation=sample_annotation['translation'],
                        size=sample_annotation['size'],
                        rotation=sample_annotation['rotation'],
                        velocity=nusc.box_velocity(sample_annotation['token'])[:2],
                        num_pts=sample_annotation['num_lidar_pts'] + sample_annotation['num_radar_pts'],
                        tracking_id=tracking_id,
                        tracking_name=tracking_name,
                        tracking_score=-1.0  # GT samples do not have a score.
                    )
                )
            else:
                raise NotImplementedError('Error: Invalid box_cls %s!' % box_cls)

        all_annotations.add_boxes(sample_token, sample_boxes)

    if verbose:
        print("Loaded ground truth annotations for {} samples.".format(len(all_annotations.sample_tokens)))

    return all_annotations
+
+
def accumulate(gt_boxes: EvalBoxes,
               pred_boxes: EvalBoxes,
               class_name: str,
               dist_fcn: Callable,
               dist_th: float,
               verbose: bool = False) -> Tuple[DetectionMetricData, float, float]:
    """
    Average Precision over predefined different recall thresholds for a single distance threshold.
    The recall/conf thresholds and other raw metrics will be used in secondary metrics.
    :param gt_boxes: Maps every sample_token to a list of its sample_annotations.
    :param pred_boxes: Maps every sample_token to a list of its sample_results.
    :param class_name: Class to compute AP on.
    :param dist_fcn: Distance function used to match detections and ground truths.
    :param dist_th: Distance threshold for a match.
    :param verbose: If true, print debug messages.
    :return: (metric_data, EPA, EPA_). metric_data holds the PR curve and per-recall
        prediction errors; EPA is the End-to-end Prediction Accuracy computed from
        minFDE of detection-matched pairs, EPA_ additionally requires the trajectory
        endpoint to be within 2m (same convention as UniAD).
    """
    # ---------------------------------------------
    # Organize input and initialize accumulators.
    # ---------------------------------------------

    # Count the positives.
    npos = len([1 for gt_box in gt_boxes.all if gt_box.detection_name == class_name])
    if verbose:
        print("Found {} GT of class {} out of {} total across {} samples.".
              format(npos, class_name, len(gt_boxes.all), len(gt_boxes.sample_tokens)))

    # For missing classes in the GT, return a data structure corresponding to no predictions.
    # NOTE: return the Motion variant and a full 3-tuple so callers can always unpack
    # (md, EPA, EPA_) — previously this path returned a 2-tuple of the plain detection md.
    if npos == 0:
        return MotionMetricData.no_predictions(), 0.0, 0.0

    # Organize the predictions in a single list.
    pred_boxes_list = [box for box in pred_boxes.all if box.detection_name == class_name]
    pred_confs = [box.detection_score for box in pred_boxes_list]

    if verbose:
        print("Found {} PRED of class {} out of {} total across {} samples.".
              format(len(pred_confs), class_name, len(pred_boxes.all), len(pred_boxes.sample_tokens)))

    # Sort by confidence.
    sortind = [i for (v, i) in sorted((v, i) for (i, v) in enumerate(pred_confs))][::-1]

    # Do the actual matching.
    tp = []  # Accumulator of true positives.
    fp = []  # Accumulator of false positives.
    conf = []  # Accumulator of confidences.
    hit = 0  # Number of matches whose minFDE is below 2m (used for EPA).

    # match_data holds the extra metrics we calculate for each match.
    match_data = {'conf': [],
                  'min_ade': [],
                  'min_fde': [],
                  'miss_rate': [],
                  'top1_fde': [],
                  'brier_min_fde': []}

    # ---------------------------------------------
    # Match and accumulate match data.
    # ---------------------------------------------

    taken = set()  # Initially no gt bounding box is matched.
    for ind in sortind:
        pred_box = pred_boxes_list[ind]
        min_dist = np.inf
        match_gt_idx = None

        for gt_idx, gt_box in enumerate(gt_boxes[pred_box.sample_token]):

            # Find closest match among ground truth boxes.
            if gt_box.detection_name == class_name and not (pred_box.sample_token, gt_idx) in taken:
                this_distance = dist_fcn(gt_box, pred_box)
                if this_distance < min_dist:
                    min_dist = this_distance
                    match_gt_idx = gt_idx

        # If the closest match is close enough according to threshold we have a match!
        is_match = min_dist < dist_th

        if is_match:
            taken.add((pred_box.sample_token, match_gt_idx))

            # Update tp, fp and confs.
            tp.append(1)
            fp.append(0)
            conf.append(pred_box.detection_score)

            # Since it is a match, update match data also.
            gt_box_match = gt_boxes[pred_box.sample_token][match_gt_idx]

            match_data['conf'].append(pred_box.detection_score)

            minade, minfde, mr, top1_fde, brier_min_fde = prediction_metrics(gt_box_match, pred_box)
            match_data['min_ade'].append(minade)
            match_data['min_fde'].append(minfde)
            match_data['miss_rate'].append(mr)
            match_data['top1_fde'].append(top1_fde)
            match_data['brier_min_fde'].append(brier_min_fde)

            # A "hit" for EPA requires the best-mode endpoint within 2m of GT.
            if minfde < 2.0:
                hit += 1

        else:
            # No match. Mark this as a false positive.
            tp.append(0)
            fp.append(1)
            conf.append(pred_box.detection_score)

    # Check if we have any matches. If not, just return a "no predictions" array.
    # Again a full 3-tuple (previously a bare md, which broke 3-way unpacking).
    if len(match_data['min_ade']) == 0:
        return MotionMetricData.no_predictions(), 0.0, 0.0

    # Accumulate.
    N_fp = np.sum(fp)
    tp = np.cumsum(tp).astype(float)
    fp = np.cumsum(fp).astype(float)
    conf = np.array(conf)

    # Calculate precision and recall.
    prec = tp / (fp + tp)
    rec = tp / float(npos)

    rec_interp = np.linspace(0, 1, DetectionMetricData.nelem)  # 101 steps, from 0% to 100% recall.
    prec = np.interp(rec_interp, rec, prec, right=0)
    conf = np.interp(rec_interp, rec, conf, right=0)
    rec = rec_interp

    # ---------------------------------------------
    # Re-sample the match-data to match, prec, recall and conf.
    # ---------------------------------------------
    confs = np.array(match_data['conf'])
    all_same_conf = confs.max() == confs.min()

    for key in match_data.keys():
        if key == "conf":
            continue  # Confidence is used as reference to align with fp and tp. So skip in this step.

        arr = np.array(match_data[key])
        if all_same_conf:
            # When all predictions have identical confidence (e.g. a GT-oracle model),
            # the recall-weighted np.interp degenerates because xp is constant and
            # np.interp requires a strictly-increasing sequence. Fall back to the
            # plain mean replicated across all recall levels.
            match_data[key] = np.full(DetectionMetricData.nelem, arr.mean())
        else:
            # For each match_data, we first calculate the accumulated mean.
            tmp = cummean(arr)
            # Then interpolate based on the confidences. (Note reversing since np.interp needs increasing arrays)
            match_data[key] = np.interp(conf[::-1], confs[::-1], tmp[::-1])[::-1]

    # End-to-end Prediction Accuracy: hits penalized by half the false positives.
    EPA = (hit - 0.5 * N_fp) / npos

    # Second matching pass that additionally requires the trajectory endpoint
    # (FDE at step 12) of the matched gt to be within 2m.
    traj_matched = 0
    taken = set()  # Initially no gt bounding box is matched.
    for ind in sortind:
        pred_box = pred_boxes_list[ind]
        min_dist = np.inf
        match_gt_idx = None

        for gt_idx, gt_box in enumerate(gt_boxes[pred_box.sample_token]):

            # Find closest match among ground truth boxes.
            if gt_box.detection_name == class_name and not (pred_box.sample_token, gt_idx) in taken:
                this_distance = dist_fcn(gt_box, pred_box)
                if this_distance < min_dist:
                    min_dist = this_distance
                    match_gt_idx = gt_idx
                    fde_distance = traj_fde(gt_box, pred_box, final_step=12)

        # If the closest match is close enough according to threshold we have a match!
        # (`and` short-circuits, so fde_distance is only read when min_dist < dist_th,
        # which guarantees the assignment above ran at least once.)
        is_match = min_dist < dist_th and fde_distance < 2.0
        if is_match:
            taken.add((pred_box.sample_token, match_gt_idx))
            traj_matched += 1
    EPA_ = (traj_matched - 0.5 * N_fp) / npos  # same as UniAD

    # ---------------------------------------------
    # Done. Instantiate MetricData and return.
    # ---------------------------------------------
    return MotionMetricData(recall=rec,
                            precision=prec,
                            confidence=conf,
                            min_ade_err=match_data['min_ade'],
                            min_fde_err=match_data['min_fde'],
                            miss_rate_err=match_data['miss_rate'],
                            top1_fde_err=match_data['top1_fde'],
                            brier_min_fde_err=match_data['brier_min_fde']), EPA, EPA_
+
+
def prediction_metrics(gt_box_match, pred_box, miss_thresh=2):
    """Compute trajectory-prediction metrics for one matched (GT, pred) pair.

    Returns (minADE, minFDE, miss, top1_fde, brier_min_fde), evaluated over the
    prediction's modes on the GT's valid future steps. All zeros if the GT has
    no future trajectory.
    """
    gt_future = np.array(gt_box_match.traj)
    modes = np.array(pred_box.traj)

    n_valid = gt_future.shape[0]
    if n_valid <= 0:
        # Nothing to score against.
        return 0, 0, 0, 0, 0

    # Per-mode, per-step Euclidean distance on the valid horizon.
    step_err = np.linalg.norm(modes[:, :n_valid, :] - gt_future[np.newaxis], axis=2)

    min_ade = step_err.mean(axis=1).min()
    final_err = step_err[:, -1]
    min_fde = final_err.min()
    # Miss: even the best mode strays beyond the threshold at some step.
    miss = step_err.max(axis=1).min() > miss_thresh

    # Top-1 FDE: FDE of the highest-confidence mode.
    # Brier-minFDE: minFDE + (1 - p_best)^2, where p_best is the normalized
    # probability assigned to the mode closest to GT (nuScenes leaderboard metric).
    mode_scores = getattr(pred_box, 'traj_score', None)
    if mode_scores is not None and len(mode_scores) == modes.shape[0]:
        scores = np.asarray(mode_scores, dtype=np.float64)
        top1_fde = float(final_err[int(np.argmax(scores))])

        total = scores.sum()
        probs = scores / total if total > 1e-6 else np.ones(len(scores)) / len(scores)
        p_best = float(probs[int(np.argmin(final_err))])
        brier_min_fde = min_fde + (1.0 - p_best) ** 2
    else:
        # No per-mode scores available: fall back to min-FDE for top1,
        # and worst-case confidence penalty for Brier-minFDE.
        top1_fde = min_fde
        brier_min_fde = min_fde + 1.0

    return min_ade, min_fde, miss, top1_fde, brier_min_fde
+
def traj_fde(gt_box, pred_box, final_step):
    """Final displacement error between the GT trajectory and the closest predicted mode.

    :param gt_box: GT box with a `traj` array of shape (T, 2).
    :param pred_box: Predicted box with a `traj` array of shape (modes, T', 2), T' >= T.
    :param final_step: 1-based future step at which to measure displacement;
        clipped to the GT trajectory length.
    :return: Minimum (over modes) Euclidean distance at the final step,
        or np.inf if the GT has no future trajectory.
    """
    if gt_box.traj.shape[0] <= 0:
        return np.inf
    final_step = min(gt_box.traj.shape[0], final_step)
    gt_final = gt_box.traj[None, final_step - 1]
    pred_final = np.array(pred_box.traj)[:, final_step - 1, :]
    # (The previous version computed `gt_final - pred_final` twice; the first
    # assignment was dead code and has been removed.)
    err = np.sqrt(np.sum(np.square(gt_final - pred_final), axis=-1))
    return np.min(err)
+
+
class MotionMetricDataList(DetectionMetricDataList):
    """ Stores a set of MotionMetricData in a dict indexed by (class name, match distance). """

    @classmethod
    def deserialize(cls, content: dict):
        """ Rebuild the list from its serialized 'name:distance -> md' mapping. """
        mdl = cls()
        for composite_key, md_serialized in content.items():
            detection_name, dist_str = composite_key.split(':')
            mdl.set(detection_name, float(dist_str), MotionMetricData.deserialize(md_serialized))
        return mdl
+
class MotionMetricData(DetectionMetricData):
    """ Holds accumulated and interpolated curves required to calculate the motion metrics. """

    nelem = 101
    # Curve attributes stored on this class, in (de)serialization order.
    _FIELDS = ('recall', 'precision', 'confidence', 'min_ade_err', 'min_fde_err',
               'miss_rate_err', 'top1_fde_err', 'brier_min_fde_err')

    def __init__(self,
                 recall: np.array,
                 precision: np.array,
                 confidence: np.array,
                 min_ade_err: np.array,
                 min_fde_err: np.array,
                 miss_rate_err: np.array,
                 top1_fde_err: np.array,
                 brier_min_fde_err: np.array):

        curves = dict(zip(self._FIELDS,
                          (recall, precision, confidence, min_ade_err, min_fde_err,
                           miss_rate_err, top1_fde_err, brier_min_fde_err)))

        # Every curve must be sampled at exactly `nelem` recall positions.
        for curve in curves.values():
            assert len(curve) == self.nelem

        # Confidences should be descending, recalls ascending.
        assert all(confidence == sorted(confidence, reverse=True))
        assert all(recall == sorted(recall))

        for name, curve in curves.items():
            setattr(self, name, curve)

    def __eq__(self, other):
        return all(np.array_equal(getattr(self, name), getattr(other, name))
                   for name in self.serialize().keys())

    @property
    def max_recall_ind(self):
        """ Returns index of max recall achieved. """
        # The last index with confidence > 0 marks the highest achieved recall;
        # with no matches at all, every confidence is zero and index 0 is used.
        non_zero = np.nonzero(self.confidence)[0]
        return non_zero[-1] if len(non_zero) > 0 else 0

    @property
    def max_recall(self):
        """ Returns max recall achieved. """
        return self.recall[self.max_recall_ind]

    def serialize(self):
        """ Serialize instance into json-friendly format. """
        return {name: getattr(self, name).tolist() for name in self._FIELDS}

    @classmethod
    def deserialize(cls, content: dict):
        """ Initialize from serialized content. """
        return cls(**{name: np.array(content[name]) for name in cls._FIELDS})

    @classmethod
    def no_predictions(cls):
        """ Returns a md instance corresponding to having no predictions. """
        values = {name: np.ones(cls.nelem) for name in cls._FIELDS}
        values['recall'] = np.linspace(0, 1, cls.nelem)
        values['precision'] = np.zeros(cls.nelem)
        values['confidence'] = np.zeros(cls.nelem)
        values['brier_min_fde_err'] = np.full(cls.nelem, 2.0)
        return cls(**values)

    @classmethod
    def random_md(cls):
        """ Returns an md instance corresponding to a random results. """
        values = {'recall': np.linspace(0, 1, cls.nelem),
                  'precision': np.random.random(cls.nelem),
                  'confidence': np.linspace(0, 1, cls.nelem)[::-1]}
        for name in ('min_ade_err', 'min_fde_err', 'miss_rate_err',
                     'top1_fde_err', 'brier_min_fde_err'):
            values[name] = np.random.random(cls.nelem)
        return cls(**values)
+
diff --git a/projects/mmdet3d_plugin/datasets/evaluation/planning/planning_eval.py b/projects/mmdet3d_plugin/datasets/evaluation/planning/planning_eval.py
new file mode 100644
index 0000000..daeb457
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/evaluation/planning/planning_eval.py
@@ -0,0 +1,209 @@
+from tqdm import tqdm
+import torch
+import torch.nn as nn
+import numpy as np
+from shapely.geometry import Polygon
+
+from mmcv.utils import print_log
+from mmdet.datasets import build_dataset, build_dataloader
+
+from projects.mmdet3d_plugin.datasets.utils import box3d_to_corners
+
+
def check_collision(ego_box, boxes):
    '''
    Check whether the ego footprint overlaps any of the given boxes in BEV.

    ego_box: tensor with shape [7], [x, y, z, w, l, h, yaw]
    boxes: tensor with shape [N, 7]
    Returns True on the first BEV polygon intersection, else False.
    '''
    if boxes.shape[0] == 0:
        return False

    # Work on a copy so the caller's tensor is not mutated by the offset below.
    ego_box = ego_box.clone()
    # follow uniad, add a 0.5m offset along the heading
    ego_box[0] += 0.5 * torch.cos(ego_box[6])
    ego_box[1] += 0.5 * torch.sin(ego_box[6])

    # Use 4 corners (x, y) of each 3D box to form BEV polygons.
    ego_corners_box = box3d_to_corners(ego_box.unsqueeze(0))[0, [0, 3, 7, 4], :2]
    corners_box = box3d_to_corners(boxes)[:, [0, 3, 7, 4], :2]
    ego_poly = Polygon([(point[0], point[1]) for point in ego_corners_box])
    for i in range(len(corners_box)):
        box_poly = Polygon([(point[0], point[1]) for point in corners_box[i]])
        if ego_poly.intersects(box_poly):
            return True

    return False
+
def get_yaw(traj):
    """Estimate per-step heading angles (rad) for a 2D trajectory.

    For a near-stationary trajectory (total displacement < 0.5 m) the heading
    is undefined, so a constant pi/2 is returned for every step.
    """
    displacement = torch.linalg.norm(traj[-1] - traj[0], dim=-1)
    if displacement < 0.5:
        return traj.new_ones(traj.shape[0]) * np.pi / 2

    # Prepend the origin so the first waypoint has a backward neighbor.
    padded = torch.cat([traj.new_zeros((1, 2)), traj], dim=0)
    headings = traj.new_zeros(traj.shape[0] + 1)
    # Central differences for interior points ...
    headings[..., 1:-1] = torch.atan2(
        padded[..., 2:, 1] - padded[..., :-2, 1],
        padded[..., 2:, 0] - padded[..., :-2, 0],
    )
    # ... and a backward difference for the final point.
    headings[..., -1] = torch.atan2(
        padded[..., -1, 1] - padded[..., -2, 1],
        padded[..., -1, 0] - padded[..., -2, 0],
    )
    # Drop the padding slot.
    return headings[1:]
+
class PlanningMetric():
    """Accumulates planning metrics over samples: per-timestep L2 error and
    collision rates of the planned ego trajectory against future agent boxes."""

    def __init__(
        self,
        n_future=6,  # number of future timesteps evaluated (0.5s each -> 3s horizon)
        compute_on_step: bool = False,  # unused here; kept for interface compatibility
    ):
        # Ego vehicle footprint in meters (W = width, H = length).
        self.W = 1.85
        self.H = 4.084

        self.n_future = n_future
        self.reset()

    def reset(self):
        """Zero all accumulators (per-timestep sums plus sample counters)."""
        self.obj_col = torch.zeros(self.n_future)           # GT-trajectory collisions
        self.obj_box_col = torch.zeros(self.n_future)       # predicted-traj collisions (excluding GT ones)
        self.obj_box_col_occluded = torch.zeros(self.n_future)  # collisions with occluded-only boxes
        self.obj_box_col_all = torch.zeros(self.n_future)   # collisions with visible + occluded boxes
        self.L2 = torch.zeros(self.n_future)                # summed per-timestep L2 error
        self.total = torch.tensor(0)                        # number of evaluated samples
        self.total_occluded = torch.tensor(0)               # samples that had any occluded boxes

    def evaluate_single_coll(self, traj, fut_boxes):
        """Per-timestep collision flags for one trajectory.

        traj: tensor (n_future, 2) of ego BEV positions.
        fut_boxes: per-timestep box tensors, indexed as fut_boxes[t][0] -> (N_t, 7).
        Returns a bool tensor (n_future,).
        """
        n_future = traj.shape[0]
        yaw = get_yaw(traj)
        # Build a full [x, y, z, w, l, h, yaw] ego box per timestep; z stays 0.
        ego_box = traj.new_zeros((n_future, 7))
        ego_box[:, :2] = traj
        ego_box[:, 3:6] = ego_box.new_tensor([self.H, self.W, 1.56])
        ego_box[:, 6] = yaw
        collision = torch.zeros(n_future, dtype=torch.bool)

        for t in range(n_future):
            # Clone because check_collision offsets the ego box in place.
            ego_box_t = ego_box[t].clone()
            boxes = fut_boxes[t][0].clone()
            collision[t] = check_collision(ego_box_t, boxes)
        return collision

    def evaluate_coll(self, trajs, gt_trajs, fut_boxes):
        """Collision counts for predicted vs GT trajectories (batch size must be 1).

        Predicted collisions at timesteps where the GT trajectory itself collides
        are not counted (those are unavoidable given the GT).
        """
        B, n_future, _ = trajs.shape
        # NOTE(review): this negates x again after update() already flipped it,
        # i.e. collisions are checked in the original x convention — confirm intended.
        trajs = trajs * torch.tensor([-1, 1], device=trajs.device)
        gt_trajs = gt_trajs * torch.tensor([-1, 1], device=gt_trajs.device)

        obj_coll_sum = torch.zeros(n_future, device=trajs.device)
        obj_box_coll_sum = torch.zeros(n_future, device=trajs.device)

        assert B == 1, 'only supprt bs=1'
        for i in range(B):
            gt_box_coll = self.evaluate_single_coll(gt_trajs[i], fut_boxes)
            box_coll = self.evaluate_single_coll(trajs[i], fut_boxes)
            # Only count predicted collisions where the GT itself does not collide.
            box_coll = torch.logical_and(box_coll, torch.logical_not(gt_box_coll))

            obj_coll_sum += gt_box_coll.long()
            obj_box_coll_sum += box_coll.long()

        return obj_coll_sum, obj_box_coll_sum

    def compute_L2(self, trajs, gt_trajs, gt_trajs_mask):
        '''
        Per-timestep Euclidean (L2) error on the (x, y) components, masked by GT validity.
        trajs: torch.Tensor (B, n_future, 3)
        gt_trajs: torch.Tensor (B, n_future, 3)
        '''
        return torch.sqrt((((trajs[:, :, :2] - gt_trajs[:, :, :2]) ** 2) * gt_trajs_mask).sum(dim=-1))

    def update(self, trajs, gt_trajs, gt_trajs_mask, fut_boxes, fut_boxes_occluded=None):
        """Accumulate metrics for one sample.

        NOTE(review): mutates `trajs` and `gt_trajs` in place (x-axis sign flip)
        before computing L2; evaluate_coll flips x back internally. The L2 value
        is unaffected since both tensors are flipped identically.
        """
        assert trajs.shape == gt_trajs.shape
        trajs[..., 0] = - trajs[..., 0]
        gt_trajs[..., 0] = - gt_trajs[..., 0]
        L2 = self.compute_L2(trajs, gt_trajs, gt_trajs_mask)
        obj_coll_sum, obj_box_coll_sum = self.evaluate_coll(trajs[:,:,:2], gt_trajs[:,:,:2], fut_boxes)

        self.obj_col += obj_coll_sum
        self.obj_box_col += obj_box_coll_sum
        self.L2 += L2.sum(dim=0)
        self.total += len(trajs)

        if fut_boxes_occluded is not None:
            # Collisions against occluded-only boxes, tracked separately.
            _, obj_box_coll_occ_sum = self.evaluate_coll(trajs[:,:,:2], gt_trajs[:,:,:2], fut_boxes_occluded)
            self.obj_box_col_occluded += obj_box_coll_occ_sum
            has_occluded = any(boxes[0].shape[0] > 0 for boxes in fut_boxes_occluded)
            self.total_occluded += int(has_occluded)

            # Visible + occluded boxes combined per timestep.
            merged_boxes = [
                [torch.cat([fut_boxes[t][0], fut_boxes_occluded[t][0]], dim=0)]
                for t in range(len(fut_boxes))
            ]
            _, obj_box_coll_all_sum = self.evaluate_coll(trajs[:,:,:2], gt_trajs[:,:,:2], merged_boxes)
            self.obj_box_col_all += obj_box_coll_all_sum

    def compute(self):
        """Return averaged metrics; occluded variants only when occluded boxes were seen."""
        occ_denom = self.total_occluded if self.total_occluded > 0 else torch.tensor(1)
        results = {
            'obj_col': self.obj_col / self.total,
            'obj_box_col': self.obj_box_col / self.total,
            'L2': self.L2 / self.total,
        }
        if self.total_occluded > 0:
            results['occluded/obj_box_col'] = self.obj_box_col_occluded / occ_denom
            results['all/obj_box_col'] = self.obj_box_col_all / self.total
        return results
+
+
def planning_eval(results, eval_config, logger, with_occlusion=False):
    """Evaluate planning results (L2 error and collision rates) over a dataset and
    log a per-horizon summary table. Returns a dict of horizon-averaged metrics."""
    dataset = build_dataset(eval_config)
    dataloader = build_dataloader(
        dataset, samples_per_gpu=1, workers_per_gpu=1, shuffle=False, dist=False)
    planning_metrics = PlanningMetric()
    occluded_samples = 0
    occluded_box_timesteps = 0

    for sample_idx, data in enumerate(tqdm(dataloader)):
        sdc_planning = data['gt_ego_fut_trajs'].cumsum(dim=-2).unsqueeze(1)
        sdc_planning_mask = data['gt_ego_fut_masks'].unsqueeze(-1).repeat(1, 1, 2).unsqueeze(1)
        command = data['gt_ego_fut_cmd'].argmax(dim=-1).item()
        fut_boxes = data['fut_boxes']
        fut_boxes_occluded = data.get('fut_boxes_occluded', None) if with_occlusion else None
        if fut_boxes_occluded is not None:
            sample_occ_count = sum(boxes[0].shape[0] for boxes in fut_boxes_occluded)
            if sample_occ_count > 0:
                occluded_samples += 1
                occluded_box_timesteps += sample_occ_count
        # Skip samples whose GT future is incomplete.
        if not sdc_planning_mask.all():
            continue
        pred_sdc_traj = results[sample_idx]['img_bbox']['final_planning'].unsqueeze(0)
        planning_metrics.update(pred_sdc_traj[:, :6, :2], sdc_planning[0, :, :6, :2],
                                sdc_planning_mask[0, :, :6, :2], fut_boxes, fut_boxes_occluded)
    if with_occlusion:
        print(f'[Occluded Planning] Samples with occluded future boxes: {occluded_samples} | Total occluded box-timestep instances: {occluded_box_timesteps}')

    planning_results = planning_metrics.compute()
    planning_metrics.reset()
    from prettytable import PrettyTable
    planning_tab = PrettyTable()
    metric_dict = {}

    planning_tab.field_names = [
        "metrics", "0.5s", "1.0s", "1.5s", "2.0s", "2.5s", "3.0s", "avg"]
    for key in planning_results.keys():
        per_step = planning_results[key].tolist()
        # Running mean up to each horizon step.
        cumulative = [np.array(per_step[:step + 1]).mean() for step in range(len(per_step))]
        # Average the 1s / 2s / 3s horizons (UniAD convention).
        avg = sum(cumulative[idx] for idx in (1, 3, 5)) / 3
        cumulative.append(avg)
        metric_dict[key] = avg
        row = [key]
        for cell in cumulative:
            if 'col' in key:
                row.append('%.3f' % float(cell * 100) + '%')
            else:
                row.append('%.4f' % float(cell))
        planning_tab.add_row(row)

    print_log('\n' + str(planning_tab), logger=logger)
    return metric_dict
diff --git a/projects/mmdet3d_plugin/datasets/map_utils/nuscmap_extractor.py b/projects/mmdet3d_plugin/datasets/map_utils/nuscmap_extractor.py
new file mode 100644
index 0000000..a3792d4
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/map_utils/nuscmap_extractor.py
@@ -0,0 +1,159 @@
+from shapely.geometry import LineString, box, Polygon
+from shapely import ops, strtree
+
+import numpy as np
+from nuscenes.map_expansion.map_api import NuScenesMap, NuScenesMapExplorer
+from nuscenes.eval.common.utils import quaternion_yaw
+from pyquaternion import Quaternion
+from .utils import split_collections, get_drivable_area_contour, \
+ get_ped_crossing_contour
+from numpy.typing import NDArray
+from typing import Dict, List, Tuple, Union
+
class NuscMapExtractor(object):
    """NuScenes map ground-truth extractor.

    Args:
        data_root (str): path to nuScenes dataset
        roi_size (tuple or list): bev range
    """
    def __init__(self, data_root: str, roi_size: Union[List, Tuple]) -> None:
        self.roi_size = roi_size
        # All four map locations shipped with nuScenes; each gets its own
        # NuScenesMap/NuScenesMapExplorer so `get_map_geom` can index by city.
        self.MAPS = ['boston-seaport', 'singapore-hollandvillage',
                     'singapore-onenorth', 'singapore-queenstown']

        self.nusc_maps = {}
        self.map_explorer = {}
        for loc in self.MAPS:
            self.nusc_maps[loc] = NuScenesMap(
                dataroot=data_root, map_name=loc)
            self.map_explorer[loc] = NuScenesMapExplorer(self.nusc_maps[loc])

        # local patch in nuScenes format
        self.local_patch = box(-roi_size[0] / 2, -roi_size[1] / 2,
                               roi_size[0] / 2, roi_size[1] / 2)

    def _union_ped(self, ped_geoms: List[Polygon]) -> List[Polygon]:
        ''' merge close ped crossings.

        Args:
            ped_geoms (list): list of Polygon

        Returns:
            union_ped_geoms (Dict): merged ped crossings
        '''

        def get_rec_direction(geom):
            # Direction of a crossing := direction of the longest edge of its
            # minimum rotated bounding rectangle (with that edge's length).
            rect = geom.minimum_rotated_rectangle
            rect_v_p = np.array(rect.exterior.coords)[:3]
            rect_v = rect_v_p[1:] - rect_v_p[:-1]
            v_len = np.linalg.norm(rect_v, axis=-1)
            longest_v_i = v_len.argmax()

            return rect_v[longest_v_i], v_len[longest_v_i]

        # NOTE(review): this id()-based lookup relies on shapely 1.x
        # STRtree.query returning geometry objects; shapely >= 2.0 returns
        # integer indices instead — confirm the pinned shapely version.
        tree = strtree.STRtree(ped_geoms)
        index_by_id = dict((id(pt), i) for i, pt in enumerate(ped_geoms))

        final_pgeom = []
        remain_idx = [i for i in range(len(ped_geoms))]
        for i, pgeom in enumerate(ped_geoms):

            if i not in remain_idx:
                continue
            # update
            remain_idx.pop(remain_idx.index(i))
            pgeom_v, pgeom_v_norm = get_rec_direction(pgeom)
            final_pgeom.append(pgeom)

            for o in tree.query(pgeom):
                o_idx = index_by_id[id(o)]
                if o_idx not in remain_idx:
                    continue

                o_v, o_v_norm = get_rec_direction(o)
                # Merge only near-parallel neighbours: |cos(theta)| close to 1.
                cos = pgeom_v.dot(o_v) / (pgeom_v_norm * o_v_norm)
                if 1 - np.abs(cos) < 0.01:  # theta < 8 degrees.
                    final_pgeom[-1] = \
                        final_pgeom[-1].union(o)
                    # update
                    remain_idx.pop(remain_idx.index(o_idx))

        results = []
        for p in final_pgeom:
            # union() may produce MultiPolygons; flatten back to Polygons
            results.extend(split_collections(p))
        return results

    def get_map_geom(self,
                     location: str,
                     translation: Union[List, NDArray],
                     rotation: Union[List, NDArray]) -> Dict[str, List[Union[LineString, Polygon]]]:
        ''' Extract geometries given `location` and self pose, self may be lidar or ego.

        Args:
            location (str): city name
            translation (array): self2global translation, shape (3,)
            rotation (array): self2global quaternion, shape (4, )

        Returns:
            geometries (Dict): extracted geometries by category.
        '''

        # (center_x, center_y, len_y, len_x) in nuscenes format
        patch_box = (translation[0], translation[1],
                     self.roi_size[1], self.roi_size[0])
        rotation = Quaternion(rotation)
        # the nuScenes map API expects the patch yaw in degrees
        yaw = quaternion_yaw(rotation) / np.pi * 180

        # get dividers
        # NOTE(review): `_get_layer_line` / `_get_layer_polygon` are private
        # NuScenesMapExplorer helpers; verify against the pinned devkit version.
        lane_dividers = self.map_explorer[location]._get_layer_line(
            patch_box, yaw, 'lane_divider')

        road_dividers = self.map_explorer[location]._get_layer_line(
            patch_box, yaw, 'road_divider')

        all_dividers = []
        for line in lane_dividers + road_dividers:
            all_dividers += split_collections(line)

        # get ped crossings
        ped_crossings = []
        ped = self.map_explorer[location]._get_layer_polygon(
            patch_box, yaw, 'ped_crossing')

        for p in ped:
            ped_crossings += split_collections(p)
        # some ped crossings are split into several small parts
        # we need to merge them
        ped_crossings = self._union_ped(ped_crossings)

        ped_crossing_lines = []
        for p in ped_crossings:
            # extract exteriors to get a closed polyline
            line = get_ped_crossing_contour(p, self.local_patch)
            if line is not None:
                ped_crossing_lines.append(line)

        # get boundaries
        # we take the union of road segments and lanes as drivable areas
        # we don't take drivable area layer in nuScenes since its definition may be ambiguous
        road_segments = self.map_explorer[location]._get_layer_polygon(
            patch_box, yaw, 'road_segment')
        lanes = self.map_explorer[location]._get_layer_polygon(
            patch_box, yaw, 'lane')
        union_roads = ops.unary_union(road_segments)
        union_lanes = ops.unary_union(lanes)
        drivable_areas = ops.unary_union([union_roads, union_lanes])

        drivable_areas = split_collections(drivable_areas)

        # boundaries are defined as the contour of drivable areas
        boundaries = get_drivable_area_contour(drivable_areas, self.roi_size)

        return dict(
            divider=all_dividers, # List[LineString]
            ped_crossing=ped_crossing_lines, # List[LineString]
            boundary=boundaries, # List[LineString]
            drivable_area=drivable_areas, # List[Polygon],
        )
+
diff --git a/projects/mmdet3d_plugin/datasets/map_utils/utils.py b/projects/mmdet3d_plugin/datasets/map_utils/utils.py
new file mode 100644
index 0000000..7dac57a
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/map_utils/utils.py
@@ -0,0 +1,119 @@
+from shapely.geometry import LineString, box, Polygon, LinearRing
+from shapely.geometry.base import BaseGeometry
+from shapely import ops
+import numpy as np
+from scipy.spatial import distance
+from typing import List, Optional, Tuple
+from numpy.typing import NDArray
+
def split_collections(geom: BaseGeometry) -> List[Optional[BaseGeometry]]:
    ''' Split Multi-geoms to list and check is valid or is empty.

    Args:
        geom (BaseGeometry): geoms to be split or validate.

    Returns:
        geometries (List): list of geometries.
    '''
    assert geom.geom_type in ['MultiLineString', 'LineString', 'MultiPolygon',
        'Polygon', 'GeometryCollection'], f"got geom type {geom.geom_type}"

    def _usable(g):
        # keep only well-formed, non-degenerate parts
        return g.is_valid and not g.is_empty

    if 'Multi' in geom.geom_type:
        # flatten a multi-part geometry into its usable members
        return [part for part in geom.geoms if _usable(part)]
    return [geom] if _usable(geom) else []
+
def get_drivable_area_contour(drivable_areas: List[Polygon],
                              roi_size: Tuple) -> List[LineString]:
    ''' Extract drivable area contours to get list of boundaries.

    Args:
        drivable_areas (list): list of drivable areas.
        roi_size (tuple): bev range size

    Returns:
        boundaries (List): list of boundaries.
    '''
    max_x = roi_size[0] / 2
    max_y = roi_size[1] / 2

    # a bit smaller than roi to avoid unexpected boundaries on edges
    local_patch = box(-max_x + 0.2, -max_y + 0.2, max_x - 0.2, max_y - 0.2)

    exteriors = []
    interiors = []

    # collect outer rings and holes of every drivable-area polygon
    for poly in drivable_areas:
        exteriors.append(poly.exterior)
        for inter in poly.interiors:
            interiors.append(inter)

    results = []
    for ext in exteriors:
        # NOTE: we make sure all exteriors are clock-wise
        # such that each boundary's right-hand-side is drivable area
        # and left-hand-side is walk way

        if ext.is_ccw:
            ext = LinearRing(list(ext.coords)[::-1])
        # clip to the (slightly shrunken) BEV patch
        lines = ext.intersection(local_patch)
        if lines.geom_type == 'MultiLineString':
            lines = ops.linemerge(lines)
        assert lines.geom_type in ['MultiLineString', 'LineString']

        results.extend(split_collections(lines))

    for inter in interiors:
        # NOTE: we make sure all interiors are counter-clock-wise
        if not inter.is_ccw:
            inter = LinearRing(list(inter.coords)[::-1])
        lines = inter.intersection(local_patch)
        if lines.geom_type == 'MultiLineString':
            lines = ops.linemerge(lines)
        assert lines.geom_type in ['MultiLineString', 'LineString']

        results.extend(split_collections(lines))

    return results
+
def get_ped_crossing_contour(polygon: Polygon,
                             local_patch: box) -> Optional[LineString]:
    ''' Extract ped crossing contours to get a closed polyline.
    Different from `get_drivable_area_contour`, this function ensures a closed polyline.

    Args:
        polygon (Polygon): ped crossing polygon to be extracted.
        local_patch (tuple): local patch params

    Returns:
        line (LineString): a closed line, or None when the crossing does not
            intersect the patch.
    '''

    ext = polygon.exterior
    # normalize orientation to counter-clock-wise
    if not ext.is_ccw:
        ext = LinearRing(list(ext.coords)[::-1])
    lines = ext.intersection(local_patch)
    # FIX: use `geom_type` (as everywhere else in this module) instead of the
    # deprecated shapely `.type` attribute, removed in recent releases.
    if lines.geom_type != 'LineString':
        # remove points in intersection results
        lines = [l for l in lines.geoms if l.geom_type != 'Point']
        lines = ops.linemerge(lines)

        # same instance but not connected.
        if lines.geom_type != 'LineString':
            ls = []
            for l in lines.geoms:
                ls.append(np.array(l.coords))

            # stitch the disconnected pieces into one polyline
            lines = np.concatenate(ls, axis=0)
            lines = LineString(lines)

    if not lines.is_empty:
        return lines

    return None
+
diff --git a/projects/mmdet3d_plugin/datasets/nuscenes_3d_dataset.py b/projects/mmdet3d_plugin/datasets/nuscenes_3d_dataset.py
new file mode 100644
index 0000000..1dbdf9c
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/nuscenes_3d_dataset.py
@@ -0,0 +1,1529 @@
+import random
+import math
+import os
+from os import path as osp
+import cv2
+import tempfile
+import copy
+import prettytable
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+import pyquaternion
+from shapely.geometry import LineString
+from nuscenes.utils.data_classes import Box as NuScenesBox
+from nuscenes.eval.detection.config import config_factory as det_configs
+from nuscenes.eval.common.config import config_factory as track_configs
+
+import mmcv
+from mmcv.utils import print_log
+from mmdet.datasets import DATASETS
+from mmdet.datasets.pipelines import Compose
+from .utils import (
+ draw_lidar_bbox3d_on_img,
+ draw_lidar_bbox3d_on_bev,
+)
+
+
+@DATASETS.register_module()
+class NuScenes3DDataset(Dataset):
    # Fallback attribute per class for the nuScenes submission format when the
    # velocity-based heuristics in `_format_bbox` do not apply.
    DefaultAttribute = {
        "car": "vehicle.parked",
        "pedestrian": "pedestrian.moving",
        "trailer": "vehicle.parked",
        "truck": "vehicle.parked",
        "bus": "vehicle.moving",
        "motorcycle": "cycle.without_rider",
        "construction_vehicle": "vehicle.parked",
        "bicycle": "cycle.without_rider",
        "barrier": "",
        "traffic_cone": "",
    }
    # nuScenes TP-error keys -> short metric names used in reported results.
    ErrNameMapping = {
        "trans_err": "mATE",
        "scale_err": "mASE",
        "orient_err": "mAOE",
        "vel_err": "mAVE",
        "attr_err": "mAAE",
    }
    # Detection classes; tuple order defines the integer label ids.
    CLASSES = (
        "car",
        "truck",
        "trailer",
        "bus",
        "construction_vehicle",
        "bicycle",
        "motorcycle",
        "pedestrian",
        "traffic_cone",
        "barrier",
    )
    # Online-mapping (vectorized map) classes.
    MAP_CLASSES = (
        'ped_crossing',
        'divider',
        'boundary',
    )
    # Per-label-id colors for visualization (channel order depends on the
    # cv2-based drawing helpers — confirm BGR vs RGB against `.utils`).
    ID_COLOR_MAP = [
        (59, 59, 238),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
        (0, 255, 255),
        (255, 0, 255),
        (255, 255, 255),
        (0, 127, 255),
        (71, 130, 255),
        (127, 127, 0),
    ]
+
    def __init__(
        self,
        ann_file,
        pipeline=None,
        data_root=None,
        classes=None,
        map_classes=None,
        load_interval=1,
        with_velocity=True,
        with_visibility=True,
        modality=None,
        test_mode=False,
        det3d_eval_version="detection_cvpr_2019",
        track3d_eval_version="tracking_nips_2019",
        version="v1.0-trainval",
        use_valid_flag=False,
        use_gt_mask=True,
        vis_score_threshold=0.25,
        data_aug_conf=None,
        sequences_split_num=1,
        with_seq_flag=False,
        keep_consistent_seq_aug=True,
        work_dir=None,
        eval_config=None,
    ):
        """NuScenes dataset covering detection, tracking, mapping, motion and
        planning annotations.

        Args:
            ann_file (str): path to the info pkl produced by data conversion.
            pipeline (list | None): mmdet-style data pipeline config.
            load_interval (int): keep every N-th frame after time-sorting.
            with_velocity (bool): append per-box velocity to gt boxes.
            with_visibility (bool): attach a lidar-visibility flag per box.
            use_valid_flag / use_gt_mask (bool): control gt box filtering,
                see `get_ann_info`.
            sequences_split_num (int | "all"): how many sub-sequences each
                scene is split into for grouped sampling (-1 / "all": one
                group per frame).
        """
        self.version = version
        self.load_interval = load_interval
        self.use_valid_flag = use_valid_flag
        self.use_gt_mask = use_gt_mask
        super().__init__()
        self.data_root = data_root
        self.ann_file = ann_file
        self.test_mode = test_mode
        self.modality = modality
        # presumably mmdet3d's LIDAR box mode (= 0) — confirm against callers
        self.box_mode_3d = 0

        if classes is not None:
            self.CLASSES = classes
        if map_classes is not None:
            self.MAP_CLASSES = map_classes
        self.cat2id = {name: i for i, name in enumerate(self.CLASSES)}
        self.data_infos = self.load_annotations(self.ann_file)

        if pipeline is not None:
            self.pipeline = Compose(pipeline)

        self.with_velocity = with_velocity
        self.with_visibility = with_visibility
        self.det3d_eval_version = det3d_eval_version
        self.det3d_eval_configs = det_configs(self.det3d_eval_version)
        # evaluate every class that has a configured eval range
        self.det3d_eval_configs.class_names = list(self.det3d_eval_configs.class_range.keys())
        self.track3d_eval_version = track3d_eval_version
        self.track3d_eval_configs = track_configs(self.track3d_eval_version)
        self.track3d_eval_configs.class_names = list(self.track3d_eval_configs.class_range.keys())
        if self.modality is None:
            # default: lidar-only modality flags
            self.modality = dict(
                use_camera=False,
                use_lidar=True,
                use_radar=False,
                use_map=False,
                use_external=False,
            )
        self.vis_score_threshold = vis_score_threshold

        self.data_aug_conf = data_aug_conf
        self.sequences_split_num = sequences_split_num
        self.keep_consistent_seq_aug = keep_consistent_seq_aug
        if with_seq_flag:
            self._set_sequence_group_flag()

        self.work_dir = work_dir
        self.eval_config = eval_config
+
    def __len__(self):
        """Number of (time-sorted, subsampled) samples in the dataset."""
        return len(self.data_infos)
+
    def _set_sequence_group_flag(self):
        """
        Set each sequence to be a different group
        """
        if self.sequences_split_num == -1:
            # one group per frame
            self.flag = np.arange(len(self.data_infos))
            return

        res = []

        curr_sequence = 0
        for idx in range(len(self.data_infos)):
            if idx != 0 and len(self.data_infos[idx]["sweeps"]) == 0:
                # Not first frame and # of sweeps is 0 -> new sequence
                curr_sequence += 1
            res.append(curr_sequence)

        self.flag = np.array(res, dtype=np.int64)

        if self.sequences_split_num != 1:
            if self.sequences_split_num == "all":
                # same as -1: every frame its own group
                self.flag = np.array(
                    range(len(self.data_infos)), dtype=np.int64
                )
            else:
                # split every scene into `sequences_split_num` contiguous
                # sub-sequences of (nearly) equal length
                bin_counts = np.bincount(self.flag)
                new_flags = []
                curr_new_flag = 0
                for curr_flag in range(len(bin_counts)):
                    # sub-sequence boundaries within this scene
                    curr_sequence_length = np.array(
                        list(
                            range(
                                0,
                                bin_counts[curr_flag],
                                math.ceil(
                                    bin_counts[curr_flag]
                                    / self.sequences_split_num
                                ),
                            )
                        )
                        + [bin_counts[curr_flag]]
                    )

                    for sub_seq_idx in (
                        curr_sequence_length[1:] - curr_sequence_length[:-1]
                    ):
                        for _ in range(sub_seq_idx):
                            new_flags.append(curr_new_flag)
                        curr_new_flag += 1

                assert len(new_flags) == len(self.flag)
                assert (
                    len(np.bincount(new_flags))
                    == len(np.bincount(self.flag)) * self.sequences_split_num
                )
                self.flag = np.array(new_flags, dtype=np.int64)
+
    def get_augmentation(self):
        """Sample one image augmentation config.

        Training: random resize / crop / flip / 2D rotate / 3D rotate drawn
        from `data_aug_conf`. Test: deterministic resize + center-ish crop,
        no flip or rotation. Returns None when `data_aug_conf` is unset.
        """
        if self.data_aug_conf is None:
            return None
        H, W = self.data_aug_conf["H"], self.data_aug_conf["W"]
        fH, fW = self.data_aug_conf["final_dim"]
        if not self.test_mode:
            resize = np.random.uniform(*self.data_aug_conf["resize_lim"])
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            # vertical crop anchored near the bottom of the resized image
            crop_h = (
                int(
                    (1 - np.random.uniform(*self.data_aug_conf["bot_pct_lim"]))
                    * newH
                )
                - fH
            )
            crop_w = int(np.random.uniform(0, max(0, newW - fW)))
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False
            if self.data_aug_conf["rand_flip"] and np.random.choice([0, 1]):
                flip = True
            rotate = np.random.uniform(*self.data_aug_conf["rot_lim"])
            rotate_3d = np.random.uniform(*self.data_aug_conf["rot3d_range"])
        else:
            # deterministic eval-time transform
            resize = max(fH / H, fW / W)
            resize_dims = (int(W * resize), int(H * resize))
            newW, newH = resize_dims
            crop_h = (
                int((1 - np.mean(self.data_aug_conf["bot_pct_lim"])) * newH)
                - fH
            )
            crop_w = int(max(0, newW - fW) / 2)
            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
            flip = False
            rotate = 0
            rotate_3d = 0
        aug_config = {
            "resize": resize,
            "resize_dims": resize_dims,
            "crop": crop,
            "flip": flip,
            "rotate": rotate,
            "rotate_3d": rotate_3d,
        }
        return aug_config
+
    def __getitem__(self, idx):
        """Load, augment and pipeline-process one sample.

        `idx` may be a plain int, or a dict {"idx": ..., "aug_config": ...}
        so a caller can force a specific augmentation (e.g. to keep it
        consistent across a sequence).
        """
        if isinstance(idx, dict):
            aug_config = idx["aug_config"]
            idx = idx["idx"]
        else:
            aug_config = self.get_augmentation()
        data = self.get_data_info(idx)
        data["aug_config"] = aug_config
        data = self.pipeline(data)
        return data
+
    def get_cat_ids(self, idx):
        """Return the set of category ids present in sample `idx`.

        NOTE(review): the filter here requires BOTH `use_valid_flag` and
        `use_gt_mask`, whereas `get_ann_info` also falls back to a
        num_lidar_pts mask — confirm the asymmetry is intentional.
        """
        info = self.data_infos[idx]
        if self.use_valid_flag and self.use_gt_mask:
            mask = info["valid_flag"]
            gt_names = set(info["gt_names"][mask])
        else:
            gt_names = set(info["gt_names"])

        cat_ids = []
        for name in gt_names:
            if name in self.CLASSES:
                cat_ids.append(self.cat2id[name])
        return cat_ids
+
    def load_annotations(self, ann_file):
        """Load the info pkl, sort frames by timestamp and subsample by
        `load_interval`; also records the pkl metadata/version on self."""
        data = mmcv.load(ann_file, file_format="pkl")
        data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"]))
        data_infos = data_infos[:: self.load_interval]
        self.metadata = data["metadata"]
        self.version = self.metadata["version"]
        print(self.metadata)
        return data_infos
+
+ def anno2geom(self, annos):
+ map_geoms = {}
+ for label, anno_list in annos.items():
+ map_geoms[label] = []
+ for anno in anno_list:
+ geom = LineString(anno)
+ map_geoms[label].append(geom)
+ return map_geoms
+
    def get_data_info(self, index):
        """Assemble the raw input dict for sample `index`: poses, map
        geometries, (optionally) camera paths/matrices, and gt annotations.
        """
        info = self.data_infos[index]
        input_dict = dict(
            token=info["token"],
            map_location=info["map_location"],
            pts_filename=info["lidar_path"],
            sweeps=info["sweeps"],
            # stored in microseconds; convert to seconds
            timestamp=info["timestamp"] / 1e6,
            lidar2ego_translation=info["lidar2ego_translation"],
            lidar2ego_rotation=info["lidar2ego_rotation"],
            ego2global_translation=info["ego2global_translation"],
            ego2global_rotation=info["ego2global_rotation"],
            ego_status=info['ego_status'].astype(np.float32),
            map_infos=info["map_annos"],
        )
        # compose lidar->ego and ego->global into a single 4x4 lidar->global
        lidar2ego = np.eye(4)
        lidar2ego[:3, :3] = pyquaternion.Quaternion(
            info["lidar2ego_rotation"]
        ).rotation_matrix
        lidar2ego[:3, 3] = np.array(info["lidar2ego_translation"])
        ego2global = np.eye(4)
        ego2global[:3, :3] = pyquaternion.Quaternion(
            info["ego2global_rotation"]
        ).rotation_matrix
        ego2global[:3, 3] = np.array(info["ego2global_translation"])
        input_dict["lidar2global"] = ego2global @ lidar2ego

        map_geoms = self.anno2geom(info["map_annos"])
        input_dict["map_geoms"] = map_geoms

        if self.modality["use_camera"]:
            image_paths = []
            lidar2img_rts = []
            lidar2cam_rts = []
            cam_intrinsic = []
            for cam_type, cam_info in info["cams"].items():
                image_paths.append(cam_info["data_path"])
                # obtain lidar to image transformation matrix
                lidar2cam_r = np.linalg.inv(cam_info["sensor2lidar_rotation"])
                lidar2cam_t = (
                    cam_info["sensor2lidar_translation"] @ lidar2cam_r.T
                )
                # lidar2cam_rt is stored transposed (row-vector layout);
                # note the `.T` when composing with the intrinsics below
                lidar2cam_rt = np.eye(4)
                lidar2cam_rt[:3, :3] = lidar2cam_r.T
                lidar2cam_rt[3, :3] = -lidar2cam_t
                intrinsic = copy.deepcopy(cam_info["cam_intrinsic"])
                cam_intrinsic.append(intrinsic)
                viewpad = np.eye(4)
                viewpad[: intrinsic.shape[0], : intrinsic.shape[1]] = intrinsic
                lidar2img_rt = viewpad @ lidar2cam_rt.T
                lidar2img_rts.append(lidar2img_rt)
                lidar2cam_rts.append(lidar2cam_rt)

            input_dict.update(
                dict(
                    img_filename=image_paths,
                    lidar2img=lidar2img_rts,
                    lidar2cam=lidar2cam_rts,
                    cam_intrinsic=cam_intrinsic,
                )
            )

        annos = self.get_ann_info(index)
        input_dict.update(annos)
        return input_dict
+
+ def get_ann_info(self, index):
+ info = self.data_infos[index]
+ if self.use_gt_mask:
+ if self.use_valid_flag:
+ mask = info["valid_flag"]
+ else:
+ mask = info["num_lidar_pts"] > 0
+ else:
+ mask = np.ones(len(info["gt_boxes"]), dtype=bool)
+ gt_bboxes_3d = info["gt_boxes"][mask]
+ gt_names_3d = info["gt_names"][mask]
+ gt_labels_3d = []
+ for cat in gt_names_3d:
+ if cat in self.CLASSES:
+ gt_labels_3d.append(self.CLASSES.index(cat))
+ else:
+ gt_labels_3d.append(-1)
+ gt_labels_3d = np.array(gt_labels_3d)
+
+ if self.with_velocity:
+ gt_velocity = info["gt_velocity"][mask]
+ nan_mask = np.isnan(gt_velocity[:, 0])
+ gt_velocity[nan_mask] = [0.0, 0.0]
+ gt_bboxes_3d = np.concatenate([gt_bboxes_3d, gt_velocity], axis=-1)
+
+ if self.with_visibility:
+ gt_visibility = (info["num_lidar_pts"][mask] > 0).astype(np.float32)
+
+ anns_results = dict(
+ gt_bboxes_3d=gt_bboxes_3d,
+ gt_labels_3d=gt_labels_3d,
+ gt_names=gt_names_3d,
+ gt_visibility=gt_visibility,
+ )
+ if "instance_inds" in info:
+ instance_inds = np.array(info["instance_inds"], dtype=np.int)[mask]
+ anns_results["instance_inds"] = instance_inds
+
+ if 'gt_agent_fut_trajs' in info:
+ anns_results['gt_agent_fut_trajs'] = info['gt_agent_fut_trajs'][mask]
+ anns_results['gt_agent_fut_masks'] = info['gt_agent_fut_masks'][mask]
+
+ if 'gt_ego_fut_trajs' in info:
+ anns_results['gt_ego_fut_trajs'] = info['gt_ego_fut_trajs']
+ anns_results['gt_ego_fut_masks'] = info['gt_ego_fut_masks']
+ anns_results['gt_ego_fut_cmd'] = info['gt_ego_fut_cmd']
+
+ ## get future box for planning eval
+ fut_ts = int(info['gt_ego_fut_masks'].sum())
+ fut_boxes = []
+ fut_boxes_occluded = []
+ cur_scene_token = info["scene_token"]
+ cur_T_global = get_T_global(info)
+ for i in range(1, fut_ts + 1):
+ fut_info = self.data_infos[index + i]
+ fut_scene_token = fut_info["scene_token"]
+ if cur_scene_token != fut_scene_token:
+ break
+ if self.use_gt_mask:
+ if self.use_valid_flag:
+ mask = fut_info["valid_flag"]
+ else:
+ mask = fut_info["num_lidar_pts"] > 0
+ else:
+ mask = np.ones(len(fut_info["gt_boxes"]), dtype=bool)
+ if self.use_valid_flag:
+ occluded_mask = ~fut_info["valid_flag"]
+ else:
+ occluded_mask = ~(fut_info["num_lidar_pts"] > 0)
+
+ fut_gt_bboxes_3d = fut_info["gt_boxes"][mask]
+ fut_gt_bboxes_occluded = fut_info["gt_boxes"][occluded_mask]
+
+ fut_T_global = get_T_global(fut_info)
+ T_fut2cur = np.linalg.inv(cur_T_global) @ fut_T_global
+
+ for bboxes in (fut_gt_bboxes_3d, fut_gt_bboxes_occluded):
+ if len(bboxes):
+ center = bboxes[:, :3] @ T_fut2cur[:3, :3].T + T_fut2cur[:3, 3]
+ yaw = np.stack([np.cos(bboxes[:, 6]), np.sin(bboxes[:, 6])], axis=-1)
+ yaw = yaw @ T_fut2cur[:2, :2].T
+ bboxes[:, :3] = center
+ bboxes[:, 6] = np.arctan2(yaw[..., 1], yaw[..., 0])
+
+ fut_boxes.append(fut_gt_bboxes_3d)
+ fut_boxes_occluded.append(fut_gt_bboxes_occluded)
+
+ anns_results['fut_boxes'] = fut_boxes
+ anns_results['fut_boxes_occluded'] = fut_boxes_occluded
+
+ return anns_results
+
    def _format_bbox(self, results, jsonfile_prefix=None, tracking=False):
        """Convert per-sample model outputs into a nuScenes submission json.

        Returns the path of the written json. With `tracking=True` the
        tracking submission schema is used instead of detection.

        NOTE(review): `self.tracking_threshold` is not assigned in the
        visible `__init__`; confirm it is set elsewhere before calling with
        tracking=True.
        """
        nusc_annos = {}
        mapped_class_names = self.CLASSES

        print("Start to convert detection format...")
        for sample_id, det in enumerate(mmcv.track_iter_progress(results)):
            annos = []
            boxes = output_to_nusc_box(
                det, threshold=self.tracking_threshold if tracking else None
            )
            sample_token = self.data_infos[sample_id]["token"]
            # lidar frame -> global frame, with class-range filtering
            boxes = lidar_nusc_box_to_global(
                self.data_infos[sample_id],
                boxes,
                mapped_class_names,
                self.det3d_eval_configs,
                self.det3d_eval_version,
            )
            for i, box in enumerate(boxes):
                name = mapped_class_names[box.label]
                # static classes are not tracked in the nuScenes benchmark
                if tracking and name in [
                    "barrier",
                    "traffic_cone",
                    "construction_vehicle",
                ]:
                    continue
                # attribute heuristic: moving if speed > 0.2 m/s
                if np.sqrt(box.velocity[0] ** 2 + box.velocity[1] ** 2) > 0.2:
                    if name in [
                        "car",
                        "construction_vehicle",
                        "bus",
                        "truck",
                        "trailer",
                    ]:
                        attr = "vehicle.moving"
                    elif name in ["bicycle", "motorcycle"]:
                        attr = "cycle.with_rider"
                    else:
                        attr = NuScenes3DDataset.DefaultAttribute[name]
                else:
                    if name in ["pedestrian"]:
                        attr = "pedestrian.standing"
                    elif name in ["bus"]:
                        attr = "vehicle.stopped"
                    else:
                        attr = NuScenes3DDataset.DefaultAttribute[name]

                nusc_anno = dict(
                    sample_token=sample_token,
                    translation=box.center.tolist(),
                    size=box.wlh.tolist(),
                    rotation=box.orientation.elements.tolist(),
                    velocity=box.velocity[:2].tolist(),
                )
                if not tracking:
                    nusc_anno.update(
                        dict(
                            detection_name=name,
                            detection_score=box.score,
                            attribute_name=attr,
                        )
                    )
                else:
                    nusc_anno.update(
                        dict(
                            tracking_name=name,
                            tracking_score=box.score,
                            tracking_id=str(box.token),
                        )
                    )

                annos.append(nusc_anno)
            nusc_annos[sample_token] = annos
        nusc_submissions = {
            "meta": self.modality,
            "results": nusc_annos,
        }

        mmcv.mkdir_or_exist(jsonfile_prefix)
        filename = "results_nusc_tracking.json" if tracking else "results_nusc.json"
        res_path = osp.join(jsonfile_prefix, filename)
        print("Results writes to", res_path)
        mmcv.dump(nusc_submissions, res_path)
        return res_path
+
    def _evaluate_single(
        self, result_path, logger=None, result_name="img_bbox", tracking=False
    ):
        """Run the official nuScenes detection (or tracking) evaluation on a
        previously written submission json and return a flat metric dict.

        Args:
            result_path (str): path of the submission json.
            result_name (str): prefix used in the returned metric keys.
            tracking (bool): run the tracking benchmark instead of detection.
        """
        from nuscenes import NuScenes

        # evaluation artifacts go next to the submission file
        output_dir = osp.join(*osp.split(result_path)[:-1])
        nusc = NuScenes(
            version=self.version, dataroot=self.data_root, verbose=False
        )
        eval_set_map = {
            "v1.0-mini": "mini_val",
            "v1.0-trainval": "val",
        }
        if not tracking:
            from nuscenes.eval.detection.evaluate import NuScenesEval

            nusc_eval = NuScenesEval(
                nusc,
                config=self.det3d_eval_configs,
                result_path=result_path,
                eval_set=eval_set_map[self.version],
                output_dir=output_dir,
                verbose=True,
            )
            nusc_eval.main(render_curves=False)

            # record metrics
            metrics = mmcv.load(osp.join(output_dir, "metrics_summary.json"))
            detail = dict()
            metric_prefix = f"{result_name}_NuScenes"
            for name in self.CLASSES:
                for k, v in metrics["label_aps"][name].items():
                    val = float("{:.4f}".format(v))
                    detail[
                        "{}/{}_AP_dist_{}".format(metric_prefix, name, k)
                    ] = val
                for k, v in metrics["label_tp_errors"][name].items():
                    val = float("{:.4f}".format(v))
                    detail["{}/{}_{}".format(metric_prefix, name, k)] = val
                for k, v in metrics["tp_errors"].items():
                    val = float("{:.4f}".format(v))
                    detail[
                        "{}/{}".format(metric_prefix, self.ErrNameMapping[k])
                    ] = val

            detail["{}/NDS".format(metric_prefix)] = metrics["nd_score"]
            detail["{}/mAP".format(metric_prefix)] = metrics["mean_ap"]
        else:
            from nuscenes.eval.tracking.evaluate import TrackingEval

            nusc_eval = TrackingEval(
                config=self.track3d_eval_configs,
                result_path=result_path,
                eval_set=eval_set_map[self.version],
                output_dir=output_dir,
                verbose=True,
                nusc_version=self.version,
                nusc_dataroot=self.data_root,
            )
            metrics = nusc_eval.main()

            # record metrics (re-read the summary file written by the eval)
            metrics = mmcv.load(osp.join(output_dir, "metrics_summary.json"))
            print(metrics)
            detail = dict()
            metric_prefix = f"{result_name}_NuScenes"
            keys = [
                "amota",
                "amotp",
                "recall",
                "motar",
                "gt",
                "mota",
                "motp",
                "mt",
                "ml",
                "faf",
                "tp",
                "fp",
                "fn",
                "ids",
                "frag",
                "tid",
                "lgd",
            ]
            for key in keys:
                detail["{}/{}".format(metric_prefix, key)] = metrics[key]

        return detail
+
    def format_results(self, results, jsonfile_prefix=None, tracking=False):
        """Write results to submission json(s).

        Returns (result_files, tmp_dir); `tmp_dir` is non-None when a
        temporary directory was created and must be kept alive until the
        files are consumed.
        """
        assert isinstance(results, list), "results must be a list"

        if jsonfile_prefix is None:
            tmp_dir = tempfile.TemporaryDirectory()
            jsonfile_prefix = osp.join(tmp_dir.name, "results")
        else:
            tmp_dir = None

        if not ("pts_bbox" in results[0] or "img_bbox" in results[0]):
            result_files = self._format_bbox(
                results, jsonfile_prefix, tracking=tracking
            )
        else:
            # one submission file per result head (pts_bbox / img_bbox)
            result_files = dict()
            for name in results[0]:
                print(f"\nFormating bboxes of {name}")
                results_ = [out[name] for out in results]
                tmp_file_ = jsonfile_prefix
                result_files.update(
                    {
                        name: self._format_bbox(
                            results_, tmp_file_, tracking=tracking
                        )
                    }
                )
        return result_files, tmp_dir
+
    def format_map_results(self, results, prefix=None):
        """Write vectorized-map predictions to `prefix`/submission_vector.json
        keyed by sample token; returns the output path."""
        submissions = {'results': {},}

        for j, pred in enumerate(results):
            '''
            For each case, the result should be formatted as Dict{'vectors': [], 'scores': [], 'labels': []}
            'vectors': List of vector, each vector is a array([[x1, y1], [x2, y2] ...]),
                contain all vectors predicted in this sample.
            'scores: List of score(float),
                contain scores of all instances in this sample.
            'labels': List of label(int),
                contain labels of all instances in this sample.
            '''
            if pred is None: # empty prediction
                continue
            pred = pred['img_bbox']

            single_case = {'vectors': [], 'scores': [], 'labels': []}
            token = self.data_infos[j]['token']
            for i in range(len(pred['scores'])):
                score = pred['scores'][i]
                label = pred['labels'][i]
                vector = pred['vectors'][i]

                # A line should have >=2 points
                if len(vector) < 2:
                    continue

                single_case['vectors'].append(vector)
                single_case['scores'].append(score)
                single_case['labels'].append(label)

            submissions['results'][token] = single_case

        out_path = osp.join(prefix, 'submission_vector.json')
        print(f'saving submissions results to {out_path}')
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        mmcv.dump(submissions, out_path)
        return out_path
+
    def format_motion_results(self, results, jsonfile_prefix=None, tracking=False, thresh=None):
        """Build an in-memory nuScenes-style submission dict that, on top of
        the detection/tracking fields, carries each box's predicted future
        trajectories (`trajs`) and optional per-mode scores (`trajs_score`).

        Unlike `_format_bbox`, nothing is written to disk and class-range
        filtering is disabled so trajectory indices stay aligned with boxes.
        """
        nusc_annos = {}
        mapped_class_names = self.CLASSES

        print("Start to convert detection format...")
        for sample_id, det in enumerate(mmcv.track_iter_progress(results)):
            annos = []
            boxes = output_to_nusc_box(
                det['img_bbox'], threshold=None
            )
            sample_token = self.data_infos[sample_id]["token"]
            boxes = lidar_nusc_box_to_global(
                self.data_infos[sample_id],
                boxes,
                mapped_class_names,
                self.det3d_eval_configs,
                self.det3d_eval_version,
                filter_with_cls_range=False,
            )
            for i, box in enumerate(boxes):
                # optional score threshold applied per box
                if thresh is not None and box.score < thresh:
                    continue
                name = mapped_class_names[box.label]
                if tracking and name in [
                    "barrier",
                    "traffic_cone",
                    "construction_vehicle",
                ]:
                    continue
                # attribute heuristic: moving if speed > 0.2 m/s
                if np.sqrt(box.velocity[0] ** 2 + box.velocity[1] ** 2) > 0.2:
                    if name in [
                        "car",
                        "construction_vehicle",
                        "bus",
                        "truck",
                        "trailer",
                    ]:
                        attr = "vehicle.moving"
                    elif name in ["bicycle", "motorcycle"]:
                        attr = "cycle.with_rider"
                    else:
                        attr = NuScenes3DDataset.DefaultAttribute[name]
                else:
                    if name in ["pedestrian"]:
                        attr = "pedestrian.standing"
                    elif name in ["bus"]:
                        attr = "vehicle.stopped"
                    else:
                        attr = NuScenes3DDataset.DefaultAttribute[name]

                nusc_anno = dict(
                    sample_token=sample_token,
                    translation=box.center.tolist(),
                    size=box.wlh.tolist(),
                    rotation=box.orientation.elements.tolist(),
                    velocity=box.velocity[:2].tolist(),
                )
                if not tracking:
                    nusc_anno.update(
                        dict(
                            detection_name=name,
                            detection_score=box.score,
                            attribute_name=attr,
                        )
                    )
                else:
                    nusc_anno.update(
                        dict(
                            tracking_name=name,
                            tracking_score=box.score,
                            tracking_id=str(box.token),
                        )
                    )
                # attach this box's predicted trajectories (index i is valid
                # because no box filtering happened upstream)
                nusc_anno.update(
                    dict(
                        trajs=det['img_bbox']['trajs_3d'][i].numpy(),
                    )
                )
                if 'trajs_score' in det['img_bbox']:
                    nusc_anno['trajs_score'] = det['img_bbox']['trajs_score'][i].numpy()
                annos.append(nusc_anno)
            nusc_annos[sample_token] = annos
        nusc_submissions = {
            "meta": self.modality,
            "results": nusc_annos,
        }

        return nusc_submissions
+
    def _evaluate_single_motion(self,
                                results,
                                result_path,
                                logger=None,
                                metric='bbox',
                                result_name='pts_bbox'):
        """Evaluation for a single model in nuScenes protocol.

        Args:
            results (dict): in-memory submission from `format_motion_results`.
            result_path (str): Path of the result file (used as output dir).
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.
            metric (str): Metric name used for evaluation. Default: 'bbox'.
            result_name (str): Result name in the metric prefix.
                Default: 'pts_bbox'.

        Returns:
            dict: Dictionary of evaluation details.
        """
        from nuscenes import NuScenes
        from .evaluation.motion.motion_eval_uniad import NuScenesEval as NuScenesEvalMotion

        output_dir = result_path
        nusc = NuScenes(
            version=self.version, dataroot=self.data_root, verbose=False)
        eval_set_map = {
            'v1.0-mini': 'mini_val',
            'v1.0-trainval': 'val',
        }
        # deep-copy the config since the motion eval may mutate it
        nusc_eval = NuScenesEvalMotion(
            nusc,
            config=copy.deepcopy(self.det3d_eval_configs),
            result_path=results,
            eval_set=eval_set_map[self.version],
            output_dir=output_dir,
            verbose=False,
            seconds=6)
        metrics = nusc_eval.main(render_curves=False)

        # report per-class motion metrics as a table
        MOTION_METRICS = ['EPA', 'min_ade_err', 'min_fde_err', 'miss_rate_err']
        class_names = ['car', 'pedestrian']

        table = prettytable.PrettyTable()
        table.field_names = ["class names"] + MOTION_METRICS
        for class_name in class_names:
            row_data = [class_name]
            for m in MOTION_METRICS:
                row_data.append('%.4f' % metrics[f'{class_name}_{m}'])
            table.add_row(row_data)
        print_log('\n'+str(table), logger=logger)
        return metrics
+
    def _evaluate_single_det_occluded(self, result_path, logger=None, result_name='img_bbox'):
        """Evaluate detection on occluded objects only (num_lidar_pts == 0)."""
        from nuscenes import NuScenes
        from .evaluation.det.occluded_det_eval import OccludedDetectionEval

        # keep occluded-eval artifacts separate from the main detection eval
        output_dir = osp.join(osp.dirname(result_path), 'occluded_det')
        nusc = NuScenes(version=self.version, dataroot=self.data_root, verbose=False)
        eval_set_map = {
            'v1.0-mini': 'mini_val',
            'v1.0-trainval': 'val',
        }
        nusc_eval = OccludedDetectionEval(
            nusc,
            config=self.det3d_eval_configs,
            result_path=result_path,
            eval_set=eval_set_map[self.version],
            output_dir=output_dir,
            verbose=False,
        )
        nusc_eval.main(render_curves=False)

        # flatten the summary json into `occluded/...` metric keys
        metrics = mmcv.load(osp.join(output_dir, 'metrics_summary.json'))
        detail = {}
        for name in self.CLASSES:
            for k, v in metrics['label_aps'].get(name, {}).items():
                detail[f'occluded/{name}_AP_dist_{k}'] = float('{:.4f}'.format(v))
            for k, v in metrics['label_tp_errors'].get(name, {}).items():
                detail[f'occluded/{name}_{k}'] = float('{:.4f}'.format(v))
        for k, v in metrics['tp_errors'].items():
            detail[f'occluded/{self.ErrNameMapping[k]}'] = float('{:.4f}'.format(v))
        detail['occluded/NDS'] = metrics['nd_score']
        detail['occluded/mAP'] = metrics['mean_ap']
        return detail
+
+ def _evaluate_single_motion_occluded(self,
+ results,
+ result_path,
+ logger=None):
+ """Evaluate motion prediction restricted to occluded objects (num_lidar_pts == 0)."""
+ from nuscenes import NuScenes
+ from .evaluation.motion.motion_eval_uniad import OccludedMotionEval
+
+ output_dir = osp.join(result_path, 'occluded_motion')
+ nusc = NuScenes(
+ version=self.version, dataroot=self.data_root, verbose=False)
+ eval_set_map = {
+ 'v1.0-mini': 'mini_val',
+ 'v1.0-trainval': 'val',
+ }
+ nusc_eval = OccludedMotionEval(
+ nusc,
+ config=copy.deepcopy(self.det3d_eval_configs),
+ result_path=results,
+ eval_set=eval_set_map[self.version],
+ output_dir=output_dir,
+ verbose=False,
+ seconds=6)
+ metrics = nusc_eval.main(render_curves=False)
+
+ MOTION_METRICS = ['EPA', 'min_ade_err', 'min_fde_err', 'miss_rate_err']
+ class_names = ['car', 'pedestrian']
+
+ table = prettytable.PrettyTable()
+ table.field_names = ["class names (occluded)"] + MOTION_METRICS
+ for class_name in class_names:
+ row_data = [class_name]
+ for m in MOTION_METRICS:
+ row_data.append('%.4f' % metrics[f'{class_name}_{m}'])
+ table.add_row(row_data)
+ print_log('\n[Occluded Objects]\n' + str(table), logger=logger)
+
+ return {f'occluded/{k}': v for k, v in metrics.items()}
+
+ def _evaluate_single_det_visible(self, result_path, logger=None, result_name='img_bbox'):
+ """Evaluate detection on visible objects, ignoring predictions that match occluded GT.
+
+ This gives a fair vis/mAP comparison between a visible-only baseline and
+ a model trained to also predict occluded objects: detections of occluded
+ objects are not penalised as false positives.
+ """
+ from nuscenes import NuScenes
+ from .evaluation.det.occluded_det_eval import VisibleDetectionEval
+
+ output_dir = osp.join(osp.dirname(result_path), 'visible_det')
+ nusc = NuScenes(version=self.version, dataroot=self.data_root, verbose=False)
+ eval_set_map = {
+ 'v1.0-mini': 'mini_val',
+ 'v1.0-trainval': 'val',
+ }
+ nusc_eval = VisibleDetectionEval(
+ nusc,
+ config=self.det3d_eval_configs,
+ result_path=result_path,
+ eval_set=eval_set_map[self.version],
+ output_dir=output_dir,
+ verbose=False,
+ )
+ nusc_eval.main(render_curves=False)
+
+ metrics = mmcv.load(osp.join(output_dir, 'metrics_summary.json'))
+ detail = {}
+ for name in self.CLASSES:
+ for k, v in metrics['label_aps'].get(name, {}).items():
+ detail[f'vis/{name}_AP_dist_{k}'] = float('{:.4f}'.format(v))
+ for k, v in metrics['label_tp_errors'].get(name, {}).items():
+ detail[f'vis/{name}_{k}'] = float('{:.4f}'.format(v))
+ for k, v in metrics['tp_errors'].items():
+ detail[f'vis/{self.ErrNameMapping[k]}'] = float('{:.4f}'.format(v))
+ detail['vis/NDS'] = metrics['nd_score']
+ detail['vis/mAP'] = metrics['mean_ap']
+ return detail
+
+ def _evaluate_single_det_all(self, result_path, logger=None, result_name='img_bbox'):
+ """Evaluate detection on all objects (visible + occluded)."""
+ from nuscenes import NuScenes
+ from .evaluation.det.occluded_det_eval import AllDetectionEval
+
+ output_dir = osp.join(osp.dirname(result_path), 'all_det')
+ nusc = NuScenes(version=self.version, dataroot=self.data_root, verbose=False)
+ eval_set_map = {
+ 'v1.0-mini': 'mini_val',
+ 'v1.0-trainval': 'val',
+ }
+ nusc_eval = AllDetectionEval(
+ nusc,
+ config=self.det3d_eval_configs,
+ result_path=result_path,
+ eval_set=eval_set_map[self.version],
+ output_dir=output_dir,
+ verbose=False,
+ )
+ nusc_eval.main(render_curves=False)
+
+ metrics = mmcv.load(osp.join(output_dir, 'metrics_summary.json'))
+ detail = {}
+ for name in self.CLASSES:
+ for k, v in metrics['label_aps'].get(name, {}).items():
+ detail[f'all/{name}_AP_dist_{k}'] = float('{:.4f}'.format(v))
+ for k, v in metrics['label_tp_errors'].get(name, {}).items():
+ detail[f'all/{name}_{k}'] = float('{:.4f}'.format(v))
+ for k, v in metrics['tp_errors'].items():
+ detail[f'all/{self.ErrNameMapping[k]}'] = float('{:.4f}'.format(v))
+ detail['all/NDS'] = metrics['nd_score']
+ detail['all/mAP'] = metrics['mean_ap']
+ return detail
+
+ def _evaluate_single_motion_all(self, results, result_path, logger=None):
+ """Evaluate motion prediction on all objects (visible + occluded)."""
+ from nuscenes import NuScenes
+ from .evaluation.motion.motion_eval_uniad import AllMotionEval
+
+ output_dir = osp.join(result_path, 'all_motion')
+ nusc = NuScenes(
+ version=self.version, dataroot=self.data_root, verbose=False)
+ eval_set_map = {
+ 'v1.0-mini': 'mini_val',
+ 'v1.0-trainval': 'val',
+ }
+ nusc_eval = AllMotionEval(
+ nusc,
+ config=copy.deepcopy(self.det3d_eval_configs),
+ result_path=results,
+ eval_set=eval_set_map[self.version],
+ output_dir=output_dir,
+ verbose=False,
+ seconds=6)
+ metrics = nusc_eval.main(render_curves=False)
+
+ MOTION_METRICS = ['EPA', 'min_ade_err', 'min_fde_err', 'miss_rate_err']
+ class_names = ['car', 'pedestrian']
+
+ table = prettytable.PrettyTable()
+ table.field_names = ["class names (all)"] + MOTION_METRICS
+ for class_name in class_names:
+ row_data = [class_name]
+ for m in MOTION_METRICS:
+ row_data.append('%.4f' % metrics[f'{class_name}_{m}'])
+ table.add_row(row_data)
+ print_log('\n[All Objects]\n' + str(table), logger=logger)
+
+ return {f'all/{k}': v for k, v in metrics.items()}
+
    def _evaluate_visibility_accuracy(self, results):
        """Evaluate the visibility head calibration on matched GT boxes.

        For each GT box (visible or occluded) we find the closest same-class
        prediction within MATCH_DIST metres (BEV). We then compare the
        prediction's sigmoid visibility score to the GT sensor-visibility flag
        (num_lidar_pts > 0).

        Returned metrics
        ----------------
        visibility/accuracy : fraction correct at 0.5 threshold
        visibility/auroc : area under the ROC curve
        visibility/accuracy_visible : accuracy on visible-GT-matched pairs
        visibility/accuracy_occluded : accuracy on occluded-GT-matched pairs
        visibility/n_matched : total matched pairs across the val set
        visibility/n_visible : matched pairs where GT is visible
        visibility/n_occluded : matched pairs where GT is occluded
        """
        MATCH_DIST = 4.0  # BEV centre-distance threshold in metres

        all_scores = []  # predicted visibility scores (sigmoid)
        all_targets = []  # GT sensor-visibility (0.0 / 1.0)

        # NOTE(review): assumes results[i] aligns with self.data_infos[i]
        # (same sample ordering) — confirm against the dataset loader.
        for i, result in enumerate(results):
            det = result.get('img_bbox', result)
            if 'visibility_scores' not in det:
                return {}  # head not enabled — skip entirely

            vis_scores = det['visibility_scores'].numpy()  # (N,)
            boxes = det['boxes_3d'].numpy()  # (N, ≥2)
            pred_labels = det['labels_3d'].numpy()  # (N,)

            info = self.data_infos[i]
            gt_boxes = info['gt_boxes']  # (M, 7) lidar frame
            gt_names = info['gt_names']  # (M,)

            # Prefer the explicit lidar-point count; fall back to valid_flag.
            if 'num_lidar_pts' in info:
                gt_vis = (info['num_lidar_pts'] > 0).astype(np.float32)
            elif 'valid_flag' in info:
                gt_vis = info['valid_flag'].astype(np.float32)
            else:
                continue

            if len(gt_boxes) == 0 or len(boxes) == 0:
                continue

            pred_centers = boxes[:, :2]  # (N, 2) BEV
            gt_centers = gt_boxes[:, :2]  # (M, 2) BEV

            for gi in range(len(gt_names)):
                gt_cls = gt_names[gi]
                if gt_cls not in self.CLASSES:
                    continue
                gt_label = self.CLASSES.index(gt_cls)

                # Only predictions of the same class are match candidates.
                cls_idx = np.where(pred_labels == gt_label)[0]
                if len(cls_idx) == 0:
                    continue

                # Nearest same-class prediction in BEV; accept if close enough.
                dists = np.linalg.norm(pred_centers[cls_idx] - gt_centers[gi], axis=1)
                nearest = dists.argmin()
                if dists[nearest] <= MATCH_DIST:
                    all_scores.append(float(vis_scores[cls_idx[nearest]]))
                    all_targets.append(float(gt_vis[gi]))

        # Need at least two matched pairs for meaningful statistics.
        if len(all_targets) < 2:
            return {}

        scores = np.array(all_scores, dtype=np.float64)
        targets = np.array(all_targets, dtype=np.float64)
        preds = (scores > 0.5).astype(np.float64)

        accuracy = float((preds == targets).mean())

        # AUROC via trapezoidal rule (no external dependency)
        # ROC points are generated by sweeping the score threshold from
        # high to low (descending sort); AUROC is only defined when both
        # classes are present.
        pos = targets.sum()
        neg = len(targets) - pos
        if pos > 0 and neg > 0:
            order = np.argsort(scores)[::-1]
            t_sorted = targets[order]
            tprs = np.concatenate([[0.0], np.cumsum(t_sorted == 1) / pos, [1.0]])
            fprs = np.concatenate([[0.0], np.cumsum(t_sorted == 0) / neg, [1.0]])
            auroc = float(np.trapz(tprs, fprs))
        else:
            auroc = float('nan')

        # Per-subset accuracies (NaN if a subset is empty).
        vis_mask = targets == 1.0
        occ_mask = targets == 0.0
        acc_vis = float((preds[vis_mask] == targets[vis_mask]).mean()) if vis_mask.any() else float('nan')
        acc_occ = float((preds[occ_mask] == targets[occ_mask]).mean()) if occ_mask.any() else float('nan')

        return {
            'visibility/accuracy': accuracy,
            'visibility/auroc': auroc,
            'visibility/accuracy_visible': acc_vis,
            'visibility/accuracy_occluded': acc_occ,
            'visibility/n_matched': int(len(targets)),
            'visibility/n_visible': int(vis_mask.sum()),
            'visibility/n_occluded': int(occ_mask.sum()),
        }
+
    def evaluate(
        self,
        results,
        eval_mode,
        metric=None,
        logger=None,
        jsonfile_prefix=None,
        result_names=["img_bbox"],
        show=False,
        out_dir=None,
        pipeline=None,
    ):
        """Run the full evaluation suite selected by ``eval_mode``.

        Depending on the flags in ``eval_mode`` this runs detection (+
        optional tracking), map, motion, occlusion-split, planning and
        visibility-head evaluations, merges all metrics into one dict,
        and prints a human-readable summary.

        Args:
            results (list[dict]): Per-sample model outputs.
            eval_mode (dict): Flags such as 'with_det', 'with_tracking',
                'with_map', 'with_motion', 'with_occlusion', 'with_planning'
                plus thresholds ('tracking_threshold', 'motion_threshhold').
            metric (str | None): Unused here; shadowed by the internal loop.
            logger (logging.Logger | str | None): Logger for printing.
            jsonfile_prefix (str | None): Unused; self.work_dir is used.
            result_names (list[str]): Keys of per-modality results to score.
            show (bool): Whether to visualise results.
            out_dir (str | None): Directory for visualisations.
            pipeline (list | None): Pipeline used by ``self.show``.

        Returns:
            dict: All collected evaluation metrics.
        """
        # Persist raw results so evaluations can be re-run offline.
        res_path = "results.pkl" if "trainval" in self.version else "results_mini.pkl"
        res_path = osp.join(self.work_dir, res_path)
        print('All Results write to', res_path)
        mmcv.dump(results, res_path)

        results_dict = dict()
        detection_result_files = None
        if eval_mode['with_det']:
            self.tracking = eval_mode["with_tracking"]
            self.tracking_threshold = eval_mode["tracking_threshold"]
            # NOTE: this loop rebinds the 'metric' parameter — intentional here,
            # the incoming argument is never used.
            for metric in ["detection", "tracking"]:
                tracking = metric == "tracking"
                if tracking and not self.tracking:
                    continue
                result_files, tmp_dir = self.format_results(
                    results, jsonfile_prefix=self.work_dir, tracking=tracking
                )
                # Keep the detection-mode files for the occlusion-split evals below.
                if not tracking:
                    detection_result_files = result_files

                if isinstance(result_files, dict):
                    for name in result_names:
                        ret_dict = self._evaluate_single(
                            result_files[name], tracking=tracking
                        )
                        results_dict.update(ret_dict)
                elif isinstance(result_files, str):
                    ret_dict = self._evaluate_single(
                        result_files, tracking=tracking
                    )
                    results_dict.update(ret_dict)

                if tmp_dir is not None:
                    tmp_dir.cleanup()

            # Visibility-head calibration (returns {} when head is absent).
            vis_metrics = self._evaluate_visibility_accuracy(results)
            results_dict.update(vis_metrics)

        if eval_mode['with_map']:
            from .evaluation.map.vector_eval import VectorEvaluate
            self.map_evaluator = VectorEvaluate(self.eval_config)
            result_path = self.format_map_results(results, prefix=self.work_dir)
            map_results_dict = self.map_evaluator.evaluate(result_path, logger=logger)
            results_dict.update(map_results_dict)

        motion_result_files = None
        if eval_mode['with_motion']:
            thresh = eval_mode["motion_threshhold"]
            motion_result_files = self.format_motion_results(results, jsonfile_prefix=self.work_dir, thresh=thresh)
            motion_results_dict = self._evaluate_single_motion(motion_result_files, self.work_dir, logger=logger)
            results_dict.update(motion_results_dict)

        if eval_mode.get('with_occlusion', False):
            thresh = eval_mode["motion_threshhold"]
            # Reuse motion files from the with_motion branch when available.
            if motion_result_files is None:
                motion_result_files = self.format_motion_results(results, jsonfile_prefix=self.work_dir, thresh=thresh)
            occluded_results_dict = self._evaluate_single_motion_occluded(
                motion_result_files, self.work_dir, logger=logger)
            results_dict.update(occluded_results_dict)

            # Occlusion-split detection evals require the detection files
            # produced in the with_det branch above.
            if detection_result_files is not None:
                if isinstance(detection_result_files, dict):
                    for name in result_names:
                        vis_det_dict = self._evaluate_single_det_visible(
                            detection_result_files[name], logger=logger, result_name=name)
                        results_dict.update(vis_det_dict)
                        occ_det_dict = self._evaluate_single_det_occluded(
                            detection_result_files[name], logger=logger, result_name=name)
                        results_dict.update(occ_det_dict)
                        all_det_dict = self._evaluate_single_det_all(
                            detection_result_files[name], logger=logger, result_name=name)
                        results_dict.update(all_det_dict)
                elif isinstance(detection_result_files, str):
                    vis_det_dict = self._evaluate_single_det_visible(
                        detection_result_files, logger=logger)
                    results_dict.update(vis_det_dict)
                    occ_det_dict = self._evaluate_single_det_occluded(
                        detection_result_files, logger=logger)
                    results_dict.update(occ_det_dict)
                    all_det_dict = self._evaluate_single_det_all(
                        detection_result_files, logger=logger)
                    results_dict.update(all_det_dict)

            all_results_dict = self._evaluate_single_motion_all(
                motion_result_files, self.work_dir, logger=logger)
            results_dict.update(all_results_dict)

        if eval_mode['with_planning']:
            from .evaluation.planning.planning_eval import planning_eval
            planning_results_dict = planning_eval(
                results, self.eval_config, logger=logger,
                with_occlusion=eval_mode.get('with_occlusion', False))
            results_dict.update(planning_results_dict)

        if show or out_dir:
            self.show(results, save_dir=out_dir, show=show, pipeline=pipeline)

        # print main metrics for recording
        metric_str = '\n'
        if "img_bbox_NuScenes/NDS" in results_dict:
            metric_str += f'mAP: {results_dict.get("img_bbox_NuScenes/mAP"):.4f}\n'
            metric_str += f'mATE: {results_dict.get("img_bbox_NuScenes/mATE"):.4f}\n'
            metric_str += f'mASE: {results_dict.get("img_bbox_NuScenes/mASE"):.4f}\n'
            metric_str += f'mAOE: {results_dict.get("img_bbox_NuScenes/mAOE"):.4f}\n'
            metric_str += f'mAVE: {results_dict.get("img_bbox_NuScenes/mAVE"):.4f}\n'
            metric_str += f'mAAE: {results_dict.get("img_bbox_NuScenes/mAAE"):.4f}\n'
            metric_str += f'NDS: {results_dict.get("img_bbox_NuScenes/NDS"):.4f}\n\n'

        if "img_bbox_NuScenes/amota" in results_dict:
            metric_str += f'AMOTA: {results_dict["img_bbox_NuScenes/amota"]:.4f}\n'
            metric_str += f'AMOTP: {results_dict["img_bbox_NuScenes/amotp"]:.4f}\n'
            metric_str += f'RECALL: {results_dict["img_bbox_NuScenes/recall"]:.4f}\n'
            metric_str += f'MOTAR: {results_dict["img_bbox_NuScenes/motar"]:.4f}\n'
            metric_str += f'MOTA: {results_dict["img_bbox_NuScenes/mota"]:.4f}\n'
            metric_str += f'MOTP: {results_dict["img_bbox_NuScenes/motp"]:.4f}\n'
            metric_str += f'IDS: {results_dict["img_bbox_NuScenes/ids"]}\n\n'

        if "mAP_normal" in results_dict:
            metric_str += f'ped_crossing= {results_dict["ped_crossing"]:.4f}\n'
            metric_str += f'divider= {results_dict["divider"]:.4f}\n'
            metric_str += f'boundary= {results_dict["boundary"]:.4f}\n'
            metric_str += f'mAP_normal= {results_dict["mAP_normal"]:.4f}\n\n'

        if "car_EPA" in results_dict:
            metric_str += f'Car / Ped\n'
            metric_str += f'epa= {results_dict["car_EPA"]:.4f} / {results_dict["pedestrian_EPA"]:.4f}\n'
            metric_str += f'ade= {results_dict["car_min_ade_err"]:.4f} / {results_dict["pedestrian_min_ade_err"]:.4f}\n'
            metric_str += f'fde= {results_dict["car_min_fde_err"]:.4f} / {results_dict["pedestrian_min_fde_err"]:.4f}\n'
            metric_str += f'mr= {results_dict["car_miss_rate_err"]:.4f} / {results_dict["pedestrian_miss_rate_err"]:.4f}\n\n'

        if "occluded/NDS" in results_dict:
            metric_str += f'[Occluded Det]\n'
            metric_str += f'mAP: {results_dict["occluded/mAP"]:.4f}\n'
            metric_str += f'mATE: {results_dict["occluded/mATE"]:.4f}\n'
            metric_str += f'mASE: {results_dict["occluded/mASE"]:.4f}\n'
            metric_str += f'mAOE: {results_dict["occluded/mAOE"]:.4f}\n'
            metric_str += f'mAVE: {results_dict["occluded/mAVE"]:.4f}\n'
            metric_str += f'mAAE: {results_dict["occluded/mAAE"]:.4f}\n'
            metric_str += f'NDS: {results_dict["occluded/NDS"]:.4f}\n\n'

        if "occluded/car_EPA" in results_dict:
            metric_str += f'[Occluded Motion] Car / Ped\n'
            metric_str += f'epa= {results_dict["occluded/car_EPA"]:.4f} / {results_dict["occluded/pedestrian_EPA"]:.4f}\n'
            metric_str += f'ade= {results_dict["occluded/car_min_ade_err"]:.4f} / {results_dict["occluded/pedestrian_min_ade_err"]:.4f}\n'
            metric_str += f'fde= {results_dict["occluded/car_min_fde_err"]:.4f} / {results_dict["occluded/pedestrian_min_fde_err"]:.4f}\n'
            metric_str += f'mr= {results_dict["occluded/car_miss_rate_err"]:.4f} / {results_dict["occluded/pedestrian_miss_rate_err"]:.4f}\n\n'

        if "all/NDS" in results_dict:
            metric_str += f'[All Det]\n'
            metric_str += f'mAP: {results_dict["all/mAP"]:.4f}\n'
            metric_str += f'mATE: {results_dict["all/mATE"]:.4f}\n'
            metric_str += f'mASE: {results_dict["all/mASE"]:.4f}\n'
            metric_str += f'mAOE: {results_dict["all/mAOE"]:.4f}\n'
            metric_str += f'mAVE: {results_dict["all/mAVE"]:.4f}\n'
            metric_str += f'mAAE: {results_dict["all/mAAE"]:.4f}\n'
            metric_str += f'NDS: {results_dict["all/NDS"]:.4f}\n\n'

        if "all/car_EPA" in results_dict:
            metric_str += f'[All Motion] Car / Ped\n'
            metric_str += f'epa= {results_dict["all/car_EPA"]:.4f} / {results_dict["all/pedestrian_EPA"]:.4f}\n'
            metric_str += f'ade= {results_dict["all/car_min_ade_err"]:.4f} / {results_dict["all/pedestrian_min_ade_err"]:.4f}\n'
            metric_str += f'fde= {results_dict["all/car_min_fde_err"]:.4f} / {results_dict["all/pedestrian_min_fde_err"]:.4f}\n'
            metric_str += f'mr= {results_dict["all/car_miss_rate_err"]:.4f} / {results_dict["all/pedestrian_miss_rate_err"]:.4f}\n\n'

        if 'visibility/accuracy' in results_dict:
            rd = results_dict
            metric_str += f'[Visibility Head]\n'
            metric_str += (
                f'accuracy= {rd["visibility/accuracy"]:.4f} '
                f'(visible={rd["visibility/accuracy_visible"]:.4f}, '
                f'occluded={rd["visibility/accuracy_occluded"]:.4f})\n'
            )
            metric_str += f'auroc= {rd["visibility/auroc"]:.4f}\n'
            metric_str += (
                f'matched: {rd["visibility/n_matched"]} '
                f'(visible={rd["visibility/n_visible"]}, '
                f'occluded={rd["visibility/n_occluded"]})\n\n'
            )

        if "L2" in results_dict:
            metric_str += f'obj_box_col: {(results_dict["obj_box_col"]*100):.3f}%\n'
            metric_str += f'L2: {results_dict["L2"]:.4f}\n'
            if "occluded/obj_box_col" in results_dict:
                metric_str += f'obj_box_col_occluded: {(results_dict["occluded/obj_box_col"]*100):.3f}%\n'
            metric_str += '\n'

        print_log(metric_str, logger=logger)
        return results_dict
+
+ def show(self, results, save_dir=None, show=False, pipeline=None):
+ save_dir = "./" if save_dir is None else save_dir
+ save_dir = os.path.join(save_dir, "visual")
+ print_log(os.path.abspath(save_dir))
+ pipeline = Compose(pipeline)
+ if not os.path.exists(save_dir):
+ os.makedirs(save_dir)
+
+ fourcc = cv2.VideoWriter_fourcc(*"MJPG")
+ videoWriter = None
+
+ for i, result in enumerate(results):
+ if "img_bbox" in result.keys():
+ result = result["img_bbox"]
+ data_info = pipeline(self.get_data_info(i))
+ imgs = []
+
+ raw_imgs = data_info["img"]
+ lidar2img = data_info["img_metas"].data["lidar2img"]
+ pred_bboxes_3d = result["boxes_3d"][
+ result["scores_3d"] > self.vis_score_threshold
+ ]
+ if "instance_ids" in result and self.tracking:
+ color = []
+ for id in result["instance_ids"].cpu().numpy().tolist():
+ color.append(
+ self.ID_COLOR_MAP[int(id % len(self.ID_COLOR_MAP))]
+ )
+ elif "labels_3d" in result:
+ color = []
+ for id in result["labels_3d"].cpu().numpy().tolist():
+ color.append(self.ID_COLOR_MAP[id])
+ else:
+ color = (255, 0, 0)
+
+ # ===== draw boxes_3d to images =====
+ for j, img_origin in enumerate(raw_imgs):
+ img = img_origin.copy()
+ if len(pred_bboxes_3d) != 0:
+ img = draw_lidar_bbox3d_on_img(
+ pred_bboxes_3d,
+ img,
+ lidar2img[j],
+ img_metas=None,
+ color=color,
+ thickness=3,
+ )
+ imgs.append(img)
+
+ # ===== draw boxes_3d to BEV =====
+ bev = draw_lidar_bbox3d_on_bev(
+ pred_bboxes_3d,
+ bev_size=img.shape[0] * 2,
+ color=color,
+ )
+
+ # ===== put text and concat =====
+ for j, name in enumerate(
+ [
+ "front",
+ "front right",
+ "front left",
+ "rear",
+ "rear left",
+ "rear right",
+ ]
+ ):
+ imgs[j] = cv2.rectangle(
+ imgs[j],
+ (0, 0),
+ (440, 80),
+ color=(255, 255, 255),
+ thickness=-1,
+ )
+ w, h = cv2.getTextSize(name, cv2.FONT_HERSHEY_SIMPLEX, 2, 2)[0]
+ text_x = int(220 - w / 2)
+ text_y = int(40 + h / 2)
+
+ imgs[j] = cv2.putText(
+ imgs[j],
+ name,
+ (text_x, text_y),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 2,
+ (0, 0, 0),
+ 2,
+ cv2.LINE_AA,
+ )
+ image = np.concatenate(
+ [
+ np.concatenate([imgs[2], imgs[0], imgs[1]], axis=1),
+ np.concatenate([imgs[5], imgs[3], imgs[4]], axis=1),
+ ],
+ axis=0,
+ )
+ image = np.concatenate([image, bev], axis=1)
+
+ # ===== save video =====
+ if videoWriter is None:
+ videoWriter = cv2.VideoWriter(
+ os.path.join(save_dir, "video.avi"),
+ fourcc,
+ 7,
+ image.shape[:2][::-1],
+ )
+ cv2.imwrite(os.path.join(save_dir, f"{i}.jpg"), image)
+ videoWriter.write(image)
+ videoWriter.release()
+
+
def output_to_nusc_box(detection, threshold=None):
    """Convert a detection dict into a list of :class:`NuScenesBox`.

    Args:
        detection (dict): Contains "boxes_3d" (a box structure exposing
            ``gravity_center``/``dims``/``yaw``/``tensor``, or a tensor-like
            array of shape (N, >=9)), "scores_3d", "labels_3d" and, when
            tracking, "instance_ids".
        threshold (float | None): If given, keep only boxes whose
            "cls_scores" (when present, otherwise "scores_3d") reach it.

    Returns:
        list[NuScenesBox]: One box per kept detection; ``box.token`` carries
        the instance id when tracking ids are available.
    """
    box3d = detection["boxes_3d"]
    scores = detection["scores_3d"].numpy()
    labels = detection["labels_3d"].numpy()
    has_ids = "instance_ids" in detection
    ids = detection["instance_ids"] if has_ids else None  # .numpy()
    if threshold is not None:
        if "cls_scores" in detection:
            mask = detection["cls_scores"].numpy() >= threshold
        else:
            mask = scores >= threshold
        box3d = box3d[mask]
        scores = scores[mask]
        labels = labels[mask]
        # Bugfix: the original indexed ``ids`` unconditionally here, raising
        # NameError whenever a threshold was given without "instance_ids".
        if has_ids:
            ids = ids[mask]

    if hasattr(box3d, "gravity_center"):
        box_gravity_center = box3d.gravity_center.numpy()
        box_dims = box3d.dims.numpy()
        # nuScenes expects (w, l, h); the source layout is (l, w, h).
        nus_box_dims = box_dims[:, [1, 0, 2]]
        box_yaw = box3d.yaw.numpy()
    else:
        box3d = box3d.numpy()
        box_gravity_center = box3d[..., :3].copy()
        box_dims = box3d[..., 3:6].copy()
        nus_box_dims = box_dims[..., [1, 0, 2]]
        box_yaw = box3d[..., 6].copy()

    # TODO: check whether this is necessary
    # with dir_offset & dir_limit in the head
    # box_yaw = -box_yaw - np.pi / 2

    box_list = []
    for i in range(len(box3d)):
        # Yaw about the z-axis as a quaternion.
        quat = pyquaternion.Quaternion(axis=[0, 0, 1], radians=box_yaw[i])
        # Velocity (vx, vy) lives in columns 7:9 of the box tensor.
        if hasattr(box3d, "gravity_center"):
            velocity = (*box3d.tensor[i, 7:9], 0.0)
        else:
            velocity = (*box3d[i, 7:9], 0.0)
        box = NuScenesBox(
            box_gravity_center[i],
            nus_box_dims[i],
            quat,
            label=labels[i],
            score=scores[i],
            velocity=velocity,
        )
        if has_ids:
            box.token = ids[i]
        box_list.append(box)
    return box_list
+
+
def lidar_nusc_box_to_global(
    info,
    boxes,
    classes,
    eval_configs,
    eval_version="detection_cvpr_2019",
    filter_with_cls_range=True,
):
    """Transform boxes from the lidar frame to the global frame.

    Boxes are first moved into the ego frame; optionally, boxes whose BEV
    distance from the ego exceeds the per-class evaluation range are dropped
    before applying the final ego-to-global transform.
    """
    kept = []
    for box in boxes:
        # lidar -> ego
        box.rotate(pyquaternion.Quaternion(info["lidar2ego_rotation"]))
        box.translate(np.array(info["lidar2ego_translation"]))
        if filter_with_cls_range:
            # Range filtering happens in the ego frame.
            max_range = eval_configs.class_range[classes[box.label]]
            if np.linalg.norm(box.center[:2], 2) > max_range:
                continue
        # ego -> global
        box.rotate(pyquaternion.Quaternion(info["ego2global_rotation"]))
        box.translate(np.array(info["ego2global_translation"]))
        kept.append(box)
    return kept
+
+
def get_T_global(info):
    """Return the 4x4 lidar-to-global transform built from the sample's calib."""
    def _pose(rotation, translation):
        # Homogeneous transform from a quaternion + translation pair.
        T = np.eye(4)
        T[:3, :3] = pyquaternion.Quaternion(rotation).rotation_matrix
        T[:3, 3] = np.array(translation)
        return T

    lidar2ego = _pose(info["lidar2ego_rotation"], info["lidar2ego_translation"])
    ego2global = _pose(info["ego2global_rotation"], info["ego2global_translation"])
    return ego2global @ lidar2ego
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/datasets/pipelines/__init__.py b/projects/mmdet3d_plugin/datasets/pipelines/__init__.py
new file mode 100644
index 0000000..ce8bec5
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/pipelines/__init__.py
@@ -0,0 +1,28 @@
+from .transform import (
+ InstanceNameFilter,
+ CircleObjectRangeFilter,
+ NormalizeMultiviewImage,
+ NuScenesSparse4DAdaptor,
+ MultiScaleDepthMapGenerator,
+)
+from .augment import (
+ ResizeCropFlipImage,
+ BBoxRotation,
+ PhotoMetricDistortionMultiViewImage,
+)
+from .loading import LoadMultiViewImageFromFiles, LoadPointsFromFile
+from .vectorize import VectorizeMap
+
+__all__ = [
+ "InstanceNameFilter",
+ "ResizeCropFlipImage",
+ "BBoxRotation",
+ "CircleObjectRangeFilter",
+ "MultiScaleDepthMapGenerator",
+ "NormalizeMultiviewImage",
+ "PhotoMetricDistortionMultiViewImage",
+ "NuScenesSparse4DAdaptor",
+ "LoadMultiViewImageFromFiles",
+ "LoadPointsFromFile",
+ "VectorizeMap",
+]
diff --git a/projects/mmdet3d_plugin/datasets/pipelines/augment.py b/projects/mmdet3d_plugin/datasets/pipelines/augment.py
new file mode 100644
index 0000000..ff99e7e
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/pipelines/augment.py
@@ -0,0 +1,233 @@
+import torch
+
+import numpy as np
+from numpy import random
+import mmcv
+from mmdet.datasets.builder import PIPELINES
+from PIL import Image
+
+
@PIPELINES.register_module()
class ResizeCropFlipImage(object):
    """Resize/crop/flip/rotate multi-view images per ``results['aug_config']``
    and update the ``lidar2img`` projection matrices accordingly."""

    def __call__(self, results):
        """Apply the augmentation to every view; no-op when 'aug_config' is absent.

        Args:
            results (dict): Must hold 'img' (list of arrays) and 'lidar2img'
                (list of 4x4 matrices); may hold 'cam_intrinsic'.

        Returns:
            dict: ``results`` with transformed images, updated projections
            and refreshed 'img_shape'.
        """
        aug_config = results.get("aug_config")
        if aug_config is None:
            return results
        imgs = results["img"]
        N = len(imgs)
        new_imgs = []
        for i in range(N):
            img, mat = self._img_transform(
                np.uint8(imgs[i]), aug_config,
            )
            new_imgs.append(np.array(img).astype(np.float32))
            # Pre-multiply so lidar points project into the augmented image.
            results["lidar2img"][i] = mat @ results["lidar2img"][i]
            if "cam_intrinsic" in results:
                # Only the resize is folded into the intrinsics here;
                # crop/flip/rotate are carried by lidar2img.
                results["cam_intrinsic"][i][:3, :3] *= aug_config["resize"]
                # results["cam_intrinsic"][i][:3, :3] = (
                #     mat[:3, :3] @ results["cam_intrinsic"][i][:3, :3]
                # )

        results["img"] = new_imgs
        results["img_shape"] = [x.shape[:2] for x in new_imgs]
        return results

    def _img_transform(self, img, aug_configs):
        """Resize, crop, optionally flip, then rotate one image.

        Returns the transformed image and the 4x4 homogeneous pixel
        transform that maps original to augmented pixel coordinates.
        """
        H, W = img.shape[:2]
        resize = aug_configs.get("resize", 1)
        resize_dims = (int(W * resize), int(H * resize))
        crop = aug_configs.get("crop", [0, 0, *resize_dims])
        flip = aug_configs.get("flip", False)
        rotate = aug_configs.get("rotate", 0)

        origin_dtype = img.dtype
        # Non-uint8 inputs are min/max rescaled to [0, 255] for PIL and
        # mapped back afterwards.
        if origin_dtype != np.uint8:
            min_value = img.min()
            max_vaule = img.max()
            scale = 255 / (max_vaule - min_value)
            img = (img - min_value) * scale
            img = np.uint8(img)
        img = Image.fromarray(img)
        img = img.resize(resize_dims).crop(crop)
        if flip:
            img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
        img = img.rotate(rotate)
        img = np.array(img).astype(np.float32)
        if origin_dtype != np.uint8:
            img = img.astype(np.float32)
            img = img / scale + min_value

        # Build the matching 2D pixel transform: resize, crop, flip, rotate
        # (about the crop centre), composed in the same order as above.
        transform_matrix = np.eye(3)
        transform_matrix[:2, :2] *= resize
        transform_matrix[:2, 2] -= np.array(crop[:2])
        if flip:
            flip_matrix = np.array(
                [[-1, 0, crop[2] - crop[0]], [0, 1, 0], [0, 0, 1]]
            )
            transform_matrix = flip_matrix @ transform_matrix
        rotate = rotate / 180 * np.pi
        rot_matrix = np.array(
            [
                [np.cos(rotate), np.sin(rotate), 0],
                [-np.sin(rotate), np.cos(rotate), 0],
                [0, 0, 1],
            ]
        )
        rot_center = np.array([crop[2] - crop[0], crop[3] - crop[1]]) / 2
        rot_matrix[:2, 2] = -rot_matrix[:2, :2] @ rot_center + rot_center
        transform_matrix = rot_matrix @ transform_matrix
        # Embed the 3x3 pixel transform into a 4x4 for lidar2img composition.
        extend_matrix = np.eye(4)
        extend_matrix[:3, :3] = transform_matrix
        return img, extend_matrix
+
+
@PIPELINES.register_module()
class BBoxRotation(object):
    """Rotate 3D GT boxes and projection matrices by ``aug_config['rotate_3d']``."""

    def __call__(self, results):
        angle = results["aug_config"]["rotate_3d"]
        cos_a = np.cos(angle)
        sin_a = np.sin(angle)

        # Homogeneous rotation about the z-axis.
        rotation = np.array(
            [
                [cos_a, -sin_a, 0, 0],
                [sin_a, cos_a, 0, 0],
                [0, 0, 1, 0],
                [0, 0, 0, 1],
            ]
        )
        inv_rotation = np.linalg.inv(rotation)

        # Fold the inverse rotation into every projection so rotated lidar
        # points still land on the same pixels.
        for view in range(len(results["lidar2img"])):
            results["lidar2img"][view] = (
                results["lidar2img"][view] @ inv_rotation
            )
        if "lidar2global" in results:
            results["lidar2global"] = results["lidar2global"] @ inv_rotation
        if "gt_bboxes_3d" in results:
            results["gt_bboxes_3d"] = self.box_rotate(
                results["gt_bboxes_3d"], angle
            )
        return results

    @staticmethod
    def box_rotate(bbox_3d, angle):
        """Rotate boxes (x, y, z, l, w, h, yaw[, vx, vy, ...]) in place by ``angle``."""
        cos_a = np.cos(angle)
        sin_a = np.sin(angle)
        rot_T = np.array(
            [[cos_a, sin_a, 0], [-sin_a, cos_a, 0], [0, 0, 1]]
        )
        bbox_3d[:, :3] = bbox_3d[:, :3] @ rot_T
        bbox_3d[:, 6] += angle
        # Rotate the velocity components when present (columns 7+).
        if bbox_3d.shape[-1] > 7:
            vel_dims = bbox_3d[:, 7:].shape[-1]
            bbox_3d[:, 7:] = bbox_3d[:, 7:] @ rot_T[:vel_dims, :vel_dims]
        return bbox_3d
+
+
@PIPELINES.register_module()
class PhotoMetricDistortionMultiViewImage:
    """Apply photometric distortion to image sequentially, every transformation
    is applied with a probability of 0.5. The position of random contrast is in
    second or second to last.
    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels
    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(
        self,
        brightness_delta=32,
        contrast_range=(0.5, 1.5),
        saturation_range=(0.5, 1.5),
        hue_delta=18,
    ):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def __call__(self, results):
        """Call function to perform photometric distortion on images.
        Args:
            results (dict): Result dict from loading pipeline.
        Returns:
            dict: Result dict with images distorted.
        """
        imgs = results["img"]
        new_imgs = []
        # Each view is distorted independently (independent random draws).
        for img in imgs:
            assert img.dtype == np.float32, (
                "PhotoMetricDistortion needs the input image of dtype np.float32,"
                ' please set "to_float32=True" in "LoadImageFromFile" pipeline'
            )
            # random brightness
            if random.randint(2):
                delta = random.uniform(
                    -self.brightness_delta, self.brightness_delta
                )
                img += delta

            # mode == 0 --> do random contrast first
            # mode == 1 --> do random contrast last
            mode = random.randint(2)
            if mode == 1:
                if random.randint(2):
                    alpha = random.uniform(
                        self.contrast_lower, self.contrast_upper
                    )
                    img *= alpha

            # convert color from BGR to HSV
            img = mmcv.bgr2hsv(img)

            # random saturation
            if random.randint(2):
                img[..., 1] *= random.uniform(
                    self.saturation_lower, self.saturation_upper
                )

            # random hue
            if random.randint(2):
                img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
                # Hue is periodic; wrap into [0, 360) degrees.
                img[..., 0][img[..., 0] > 360] -= 360
                img[..., 0][img[..., 0] < 0] += 360

            # convert color from HSV to BGR
            img = mmcv.hsv2bgr(img)

            # random contrast
            if mode == 0:
                if random.randint(2):
                    alpha = random.uniform(
                        self.contrast_lower, self.contrast_upper
                    )
                    img *= alpha

            # randomly swap channels
            if random.randint(2):
                img = img[..., random.permutation(3)]
            new_imgs.append(img)
        results["img"] = new_imgs
        return results

    def __repr__(self):
        # Human-readable config dump, mirroring the constructor arguments.
        repr_str = self.__class__.__name__
        repr_str += f"(\nbrightness_delta={self.brightness_delta},\n"
        repr_str += "contrast_range="
        repr_str += f"{(self.contrast_lower, self.contrast_upper)},\n"
        repr_str += "saturation_range="
        repr_str += f"{(self.saturation_lower, self.saturation_upper)},\n"
        repr_str += f"hue_delta={self.hue_delta})"
        return repr_str
diff --git a/projects/mmdet3d_plugin/datasets/pipelines/loading.py b/projects/mmdet3d_plugin/datasets/pipelines/loading.py
new file mode 100644
index 0000000..cb743ec
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/pipelines/loading.py
@@ -0,0 +1,188 @@
+import numpy as np
+import mmcv
+from mmdet.datasets.builder import PIPELINES
+
+
@PIPELINES.register_module()
class LoadMultiViewImageFromFiles(object):
    """Load multi-view images from a list of per-view image files.

    Expects ``results['img_filename']`` to be a list of filenames.

    Args:
        to_float32 (bool, optional): Whether to cast the loaded images to
            float32. Defaults to False.
        color_type (str, optional): Color type passed to ``mmcv.imread``.
            Defaults to 'unchanged'.
    """

    def __init__(self, to_float32=False, color_type="unchanged"):
        self.to_float32 = to_float32
        self.color_type = color_type

    def __call__(self, results):
        """Load every view listed in ``results['img_filename']``.

        Args:
            results (dict): Result dict containing multi-view image
                filenames.

        Returns:
            dict: Updated dict with the keys ``filename``, ``img`` (list of
            per-view arrays), ``img_shape``, ``ori_shape``, ``pad_shape``,
            ``scale_factor`` and ``img_norm_cfg``.
        """
        filenames = results["img_filename"]
        views = [mmcv.imread(name, self.color_type) for name in filenames]
        # Stacked shape is (h, w, c, num_views).
        stacked = np.stack(views, axis=-1)
        if self.to_float32:
            stacked = stacked.astype(np.float32)

        results["filename"] = filenames
        # Keep images as a list; `DefaultFormatBundle` in formatting.py
        # transposes each view separately and then stacks them.
        results["img"] = [stacked[..., idx] for idx in range(stacked.shape[-1])]
        results["img_shape"] = stacked.shape
        results["ori_shape"] = stacked.shape
        # Initial values for the default meta keys.
        results["pad_shape"] = stacked.shape
        results["scale_factor"] = 1.0
        channels = stacked.shape[2] if stacked.ndim >= 3 else 1
        # Identity normalization recorded until a Normalize step runs.
        results["img_norm_cfg"] = dict(
            mean=np.zeros(channels, dtype=np.float32),
            std=np.ones(channels, dtype=np.float32),
            to_rgb=False,
        )
        return results

    def __repr__(self):
        """str: Human-readable module description."""
        return (
            f"{self.__class__.__name__}"
            f"(to_float32={self.to_float32}, "
            f"color_type='{self.color_type}')"
        )
+
+
@PIPELINES.register_module()
class LoadPointsFromFile(object):
    """Load a point cloud from file into ``results['points']``.

    Args:
        coord_type (str): Coordinate frame of the points; one of
            'CAMERA', 'LIDAR' or 'DEPTH'.
        load_dim (int, optional): Number of values stored per point in the
            file. Defaults to 6.
        use_dim (int | list[int], optional): Which loaded dimensions to
            keep; an int ``n`` is shorthand for ``list(range(n))``.
            Defaults to [0, 1, 2]. For KITTI, use 4 or [0, 1, 2, 3] to
            include intensity.
        shift_height (bool, optional): If True, append a height channel
            relative to an estimated floor height. Defaults to False.
        use_color (bool, optional): If True, record which columns hold RGB
            color in ``attribute_dims``. Defaults to False.
        file_client_args (dict, optional): Config for ``mmcv.FileClient``.
            Defaults to ``dict(backend='disk')``.
    """

    def __init__(
        self,
        coord_type,
        load_dim=6,
        use_dim=[0, 1, 2],
        shift_height=False,
        use_color=False,
        file_client_args=dict(backend="disk"),
    ):
        self.shift_height = shift_height
        self.use_color = use_color
        if isinstance(use_dim, int):
            use_dim = list(range(use_dim))
        assert max(use_dim) < load_dim, (
            f"Expect all used dimensions < {load_dim}, got {use_dim}"
        )
        assert coord_type in ["CAMERA", "LIDAR", "DEPTH"]

        self.coord_type = coord_type
        self.load_dim = load_dim
        self.use_dim = use_dim
        self.file_client_args = file_client_args.copy()
        # Created lazily on first load (may live in a dataloader worker).
        self.file_client = None

    def _load_points(self, pts_filename):
        """Read raw point data from ``pts_filename``.

        Falls back to local-disk loading when the file client raises a
        ``ConnectionError`` (e.g. a remote backend is unreachable).

        Args:
            pts_filename (str): Path of the point cloud file.

        Returns:
            np.ndarray: Flat float32 array of point data.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)
        try:
            buf = self.file_client.get(pts_filename)
            points = np.frombuffer(buf, dtype=np.float32)
        except ConnectionError:
            mmcv.check_file_exist(pts_filename)
            if pts_filename.endswith(".npy"):
                points = np.load(pts_filename)
            else:
                points = np.fromfile(pts_filename, dtype=np.float32)
        return points

    def __call__(self, results):
        """Load, reshape and post-process the point cloud.

        Args:
            results (dict): Result dict containing ``pts_filename``.

        Returns:
            dict: Same dict with the key ``points`` added.
        """
        pts_filename = results["pts_filename"]
        points = self._load_points(pts_filename).reshape(-1, self.load_dim)
        points = points[:, self.use_dim]
        attribute_dims = None

        if self.shift_height:
            # NOTE(review): 0.99 is the percentile *value* passed to
            # np.percentile (i.e. the 0.99th percentile, not the 99th);
            # kept as-is since it matches the upstream implementation.
            floor_height = np.percentile(points[:, 2], 0.99)
            height = points[:, 2] - floor_height
            points = np.concatenate(
                [points[:, :3], np.expand_dims(height, 1), points[:, 3:]], 1
            )
            attribute_dims = dict(height=3)

        if self.use_color:
            assert len(self.use_dim) >= 6
            if attribute_dims is None:
                attribute_dims = dict()
            # Convention: the last three kept columns are RGB.
            last = points.shape[1]
            attribute_dims.update(dict(color=[last - 3, last - 2, last - 1]))

        results["points"] = points
        return results
diff --git a/projects/mmdet3d_plugin/datasets/pipelines/transform.py b/projects/mmdet3d_plugin/datasets/pipelines/transform.py
new file mode 100644
index 0000000..4fe79bf
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/pipelines/transform.py
@@ -0,0 +1,250 @@
+import numpy as np
+import mmcv
+from mmcv.parallel import DataContainer as DC
+from mmdet.datasets.builder import PIPELINES
+from mmdet.datasets.pipelines import to_tensor
+
+
@PIPELINES.register_module()
class MultiScaleDepthMapGenerator(object):
    """Project LiDAR points into each camera view to build sparse depth maps.

    One depth map is produced per view and per downsample factor; pixels
    with no projected point keep the sentinel value -1.

    Args:
        downsample (int | list[int]): Downsample factor(s) of the output
            depth maps relative to the input image size. Defaults to 1.
        max_depth (float): Depth values are clipped to [0.1, max_depth].
            Defaults to 60.
    """

    def __init__(self, downsample=1, max_depth=60):
        if not isinstance(downsample, (list, tuple)):
            downsample = [downsample]
        self.downsample = downsample
        self.max_depth = max_depth

    def __call__(self, input_dict):
        points = input_dict["points"][..., :3, None]
        gt_depth = []
        for view_idx, lidar2img in enumerate(input_dict["lidar2img"]):
            H, W = input_dict["img_shape"][view_idx][:2]

            # Project LiDAR points into this view's image plane.
            pts_2d = (
                np.squeeze(lidar2img[:3, :3] @ points, axis=-1)
                + lidar2img[:3, 3]
            )
            pts_2d[:, :2] /= pts_2d[:, 2:3]
            cols = np.round(pts_2d[:, 0]).astype(np.int32)
            rows = np.round(pts_2d[:, 1]).astype(np.int32)
            depths = pts_2d[:, 2]

            # Keep points inside the image and in front of the camera; the
            # 0.1 m threshold avoids near-zero-depth projection artifacts.
            valid = np.logical_and.reduce(
                [rows >= 0, rows < H, cols >= 0, cols < W, depths >= 0.1]
            )
            rows, cols, depths = rows[valid], cols[valid], depths[valid]

            # Sort far-to-near so nearer points overwrite farther ones when
            # several land on the same pixel.
            order = np.argsort(depths)[::-1]
            rows, cols, depths = rows[order], cols[order], depths[order]
            depths = np.clip(depths, 0.1, self.max_depth)

            for scale_idx, factor in enumerate(self.downsample):
                if len(gt_depth) <= scale_idx:
                    gt_depth.append([])
                h, w = (int(H / factor), int(W / factor))
                u = np.floor(cols / factor).astype(np.int32)
                v = np.floor(rows / factor).astype(np.int32)
                depth_map = np.full([h, w], -1, dtype=np.float32)
                depth_map[v, u] = depths
                gt_depth[scale_idx].append(depth_map)

        input_dict["gt_depth"] = [np.stack(maps) for maps in gt_depth]
        return input_dict
+
+
@PIPELINES.register_module()
class NuScenesSparse4DAdaptor(object):
    """Convert loaded nuScenes annotations into Sparse4D model inputs.

    Builds per-view projection matrices and global transforms, wraps yaw
    into a canonical period, and packages images and ground-truth arrays
    into ``DataContainer``s for the collate step.
    """

    def __init__(self):
        # BUG FIX: this method was previously misspelled ``__init`` and so
        # was never invoked as the constructor (object.__init__ ran
        # instead). It takes no arguments, so behavior is unchanged.
        pass

    def __call__(self, input_dict):
        """Adapt one sample dict in place and return it.

        Args:
            input_dict (dict): Result dict from the loading pipeline.

        Returns:
            dict: The same dict with model-input keys added/converted.
        """
        # Per-view lidar -> image projection matrices.
        input_dict["projection_mat"] = np.float32(
            np.stack(input_dict["lidar2img"])
        )
        # (num_views, 2) image sizes reordered as (width, height).
        input_dict["image_wh"] = np.ascontiguousarray(
            np.array(input_dict["img_shape"], dtype=np.float32)[:, :2][:, ::-1]
        )
        input_dict["T_global_inv"] = np.linalg.inv(input_dict["lidar2global"])
        input_dict["T_global"] = input_dict["lidar2global"]
        if "cam_intrinsic" in input_dict:
            input_dict["cam_intrinsic"] = np.float32(
                np.stack(input_dict["cam_intrinsic"])
            )
            # Focal length taken from the (0, 0) entry of each intrinsic.
            input_dict["focal"] = input_dict["cam_intrinsic"][..., 0, 0]
        if "instance_inds" in input_dict:
            input_dict["instance_id"] = input_dict["instance_inds"]

        if "gt_bboxes_3d" in input_dict:
            # Wrap yaw (column 6) into [-pi, pi) before tensor conversion.
            input_dict["gt_bboxes_3d"][:, 6] = self.limit_period(
                input_dict["gt_bboxes_3d"][:, 6], offset=0.5, period=2 * np.pi
            )
            input_dict["gt_bboxes_3d"] = DC(
                to_tensor(input_dict["gt_bboxes_3d"]).float()
            )
        if "gt_labels_3d" in input_dict:
            input_dict["gt_labels_3d"] = DC(
                to_tensor(input_dict["gt_labels_3d"]).long()
            )
        if "gt_visibility" in input_dict:
            input_dict["gt_visibility"] = DC(
                to_tensor(input_dict["gt_visibility"]).float()
            )

        # HWC -> CHW per view, then stack views into one contiguous array.
        imgs = [img.transpose(2, 0, 1) for img in input_dict["img"]]
        imgs = np.ascontiguousarray(np.stack(imgs, axis=0))
        input_dict["img"] = DC(to_tensor(imgs), stack=True)

        # Variable-length per-instance targets: not stacked across samples.
        for key in [
            "gt_map_labels",
            "gt_map_pts",
            "gt_agent_fut_trajs",
            "gt_agent_fut_masks",
        ]:
            if key not in input_dict:
                continue
            input_dict[key] = DC(
                to_tensor(input_dict[key]), stack=False, cpu_only=False
            )

        # Fixed-size ego-level targets: safe to batch-stack without padding.
        for key in [
            "gt_ego_fut_trajs",
            "gt_ego_fut_masks",
            "gt_ego_fut_cmd",
            "ego_status",
        ]:
            if key not in input_dict:
                continue
            input_dict[key] = DC(
                to_tensor(input_dict[key]),
                stack=True,
                cpu_only=False,
                pad_dims=None,
            )

        return input_dict

    def limit_period(
        self, val: np.ndarray, offset: float = 0.5, period: float = np.pi
    ) -> np.ndarray:
        """Wrap ``val`` into ``[-offset * period, (1 - offset) * period)``."""
        return val - np.floor(val / period + offset) * period
+
+
@PIPELINES.register_module()
class InstanceNameFilter(object):
    """Keep only GT instances whose label is among the training classes.

    Args:
        classes (list[str]): Class names kept for training; valid labels
            are their indices in this list.
    """

    def __init__(self, classes):
        self.classes = classes
        self.labels = list(range(len(self.classes)))

    def __call__(self, input_dict):
        """Filter every per-instance annotation array by label.

        Args:
            input_dict (dict): Result dict from the loading pipeline.

        Returns:
            dict: Results after filtering; 'gt_bboxes_3d', 'gt_labels_3d'
            (and, when present, instance ids / future trajectories /
            visibility) are filtered consistently.
        """
        labels = input_dict["gt_labels_3d"]
        keep = np.array([lbl in self.labels for lbl in labels], dtype=np.bool_)
        input_dict["gt_bboxes_3d"] = input_dict["gt_bboxes_3d"][keep]
        input_dict["gt_labels_3d"] = input_dict["gt_labels_3d"][keep]
        if "instance_inds" in input_dict:
            input_dict["instance_inds"] = input_dict["instance_inds"][keep]
        if "gt_agent_fut_trajs" in input_dict:
            input_dict["gt_agent_fut_trajs"] = input_dict["gt_agent_fut_trajs"][keep]
            input_dict["gt_agent_fut_masks"] = input_dict["gt_agent_fut_masks"][keep]
        if "gt_visibility" in input_dict:
            input_dict["gt_visibility"] = input_dict["gt_visibility"][keep]
        return input_dict

    def __repr__(self):
        """str: Human-readable module description."""
        return f"{self.__class__.__name__}(classes={self.classes})"
+
+
@PIPELINES.register_module()
class CircleObjectRangeFilter(object):
    """Drop GT boxes beyond a per-class radial (BEV) distance threshold.

    Args:
        class_dist_thred (list[float]): Maximum BEV distance from the
            origin, indexed by class label.
    """

    def __init__(
        self, class_dist_thred=[52.5] * 5 + [31.5] + [42] * 3 + [31.5]
    ):
        self.class_dist_thred = class_dist_thred

    def __call__(self, input_dict):
        boxes = input_dict["gt_bboxes_3d"]
        labels = input_dict["gt_labels_3d"]
        # BEV (x, y) distance of every box center from the origin.
        dist = np.sqrt(np.sum(boxes[:, :2] ** 2, axis=-1))
        keep = np.zeros(len(dist), dtype=bool)
        for label_idx, threshold in enumerate(self.class_dist_thred):
            keep |= np.logical_and(labels == label_idx, dist <= threshold)

        input_dict["gt_bboxes_3d"] = boxes[keep]
        input_dict["gt_labels_3d"] = labels[keep]
        # Apply the same mask to every parallel per-instance array.
        if "instance_inds" in input_dict:
            input_dict["instance_inds"] = input_dict["instance_inds"][keep]
        if "gt_agent_fut_trajs" in input_dict:
            input_dict["gt_agent_fut_trajs"] = input_dict["gt_agent_fut_trajs"][keep]
            input_dict["gt_agent_fut_masks"] = input_dict["gt_agent_fut_masks"][keep]
        if "gt_visibility" in input_dict:
            input_dict["gt_visibility"] = input_dict["gt_visibility"][keep]
        return input_dict

    def __repr__(self):
        """str: Human-readable module description."""
        return f"{self.__class__.__name__}(class_dist_thred={self.class_dist_thred})"
+
+
@PIPELINES.register_module()
class NormalizeMultiviewImage(object):
    """Normalize every view of a multi-view image sample.

    Adds the key "img_norm_cfg" describing the applied normalization.

    Args:
        mean (sequence): Per-channel mean values.
        std (sequence): Per-channel std values.
        to_rgb (bool): Whether to convert images from BGR to RGB;
            default is true.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Normalize each image in ``results['img']``.

        Args:
            results (dict): Result dict from the loading pipeline.

        Returns:
            dict: Normalized results with 'img_norm_cfg' added.
        """
        normalized = []
        for img in results["img"]:
            normalized.append(
                mmcv.imnormalize(img, self.mean, self.std, self.to_rgb)
            )
        results["img"] = normalized
        results["img_norm_cfg"] = dict(
            mean=self.mean, std=self.std, to_rgb=self.to_rgb
        )
        return results

    def __repr__(self):
        """str: Human-readable module description."""
        return (
            f"{self.__class__.__name__}"
            f"(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})"
        )
diff --git a/projects/mmdet3d_plugin/datasets/pipelines/vectorize.py b/projects/mmdet3d_plugin/datasets/pipelines/vectorize.py
new file mode 100644
index 0000000..39d8729
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/pipelines/vectorize.py
@@ -0,0 +1,208 @@
+from typing import List, Tuple, Union, Dict
+
+import numpy as np
+from shapely.geometry import LineString
+from numpy.typing import NDArray
+
+from mmcv.parallel import DataContainer as DC
+from mmdet.datasets.builder import PIPELINES
+
+
@PIPELINES.register_module(force=True)
class VectorizeMap(object):
    """Generate a vectorized map and put it into the result dict.

    Shapely geometry objects are converted into sampled points (ndarray).
    The args `sample_num`, `sample_dist` and `simplify` select the
    sampling method.

    Args:
        roi_size (tuple or list): bev range.
        normalize (bool): whether to normalize points to range (0, 1).
        coords_dim (int): dimension of point coordinates.
        simplify (bool): whether to use the simplify function. If true,
            `sample_num` and `sample_dist` will be ignored.
        sample_num (int): number of points to interpolate from a polyline.
            Set to -1 to ignore.
        sample_dist (float): interpolation distance. Set to -1 to ignore.
        permute (bool): whether to expand each polyline into all of its
            equivalent start-point / direction permutations.
    """

    def __init__(self,
                 roi_size: Union[Tuple, List],
                 normalize: bool,
                 coords_dim: int = 2,
                 simplify: bool = False,
                 sample_num: int = -1,
                 sample_dist: float = -1,
                 permute: bool = False
                 ):
        self.coords_dim = coords_dim
        self.sample_num = sample_num
        self.sample_dist = sample_dist
        self.roi_size = np.array(roi_size)
        self.normalize = normalize
        self.simplify = simplify
        self.permute = permute

        # Exactly one of the three sampling strategies must be selected.
        if sample_dist > 0:
            assert sample_num < 0 and not simplify
            self.sample_fn = self.interp_fixed_dist
        elif sample_num > 0:
            assert sample_dist < 0 and not simplify
            self.sample_fn = self.interp_fixed_num
        else:
            assert simplify

    def interp_fixed_num(self, line: LineString) -> NDArray:
        '''Interpolate a line to a fixed number of points.

        Args:
            line (LineString): line

        Returns:
            points (array): interpolated points, shape (N, 2)
        '''
        distances = np.linspace(0, line.length, self.sample_num)
        sampled_points = np.array([list(line.interpolate(distance).coords)
                                   for distance in distances]).squeeze()

        return sampled_points

    def interp_fixed_dist(self, line: LineString) -> NDArray:
        '''Interpolate a line at a fixed interval.

        Args:
            line (LineString): line

        Returns:
            points (array): interpolated points, shape (N, 2)
        '''
        distances = list(np.arange(self.sample_dist, line.length, self.sample_dist))
        # make sure to sample at least two points when sample_dist > line.length
        distances = [0,] + distances + [line.length,]

        sampled_points = np.array([list(line.interpolate(distance).coords)
                                   for distance in distances]).squeeze()

        return sampled_points

    def get_vectorized_lines(self, map_geoms: Dict) -> Dict:
        '''Vectorize map elements. Iterate over the input dict and apply
        the configured sample function to every LineString.

        Args:
            map_geoms (Dict): label -> list of shapely geometries.

        Returns:
            vectors (Dict): dict of vectorized map elements per label.
        '''
        vectors = {}
        for label, geom_list in map_geoms.items():
            vectors[label] = []
            for geom in geom_list:
                if geom.geom_type == 'LineString':
                    if self.simplify:
                        line = geom.simplify(0.2, preserve_topology=True)
                        line = np.array(line.coords)
                    else:
                        line = self.sample_fn(geom)
                    line = line[:, :self.coords_dim]

                    if self.normalize:
                        line = self.normalize_line(line)
                    if self.permute:
                        line = self.permute_line(line)
                    vectors[label].append(line)

                elif geom.geom_type == 'Polygon':
                    # polygon objects will not be vectorized
                    continue

                else:
                    raise ValueError('map geoms must be either LineString or Polygon!')
        return vectors

    def normalize_line(self, line: NDArray) -> NDArray:
        '''Convert points to range (0, 1).

        Args:
            line (NDArray): points of a single line.

        Returns:
            normalized (array): normalized points.
        '''
        # Shift the ROI's lower-left corner to the origin.
        origin = -np.array([self.roi_size[0]/2, self.roi_size[1]/2])

        line[:, :2] = line[:, :2] - origin

        # transform from range [0, 1] to (0, 1)
        eps = 1e-5
        line[:, :2] = line[:, :2] / (self.roi_size + eps)

        return line

    def permute_line(self, line: np.ndarray, padding=1e5):
        '''Expand one polyline into its equivalent permutations.

        (num_pts, 2) -> (num_permute, num_pts, 2)
        where num_permute = 2 * (num_pts - 1)
        '''
        # A line counts as closed when its endpoints (nearly) coincide.
        is_closed = np.allclose(line[0], line[-1], atol=1e-3)
        num_points = len(line)
        permute_num = num_points - 1
        permute_lines_list = []
        if is_closed:
            # Closed loop: every rotation (both directions) is equivalent.
            pts_to_permute = line[:-1, :]  # throw away replicate start end pts
            for shift_i in range(permute_num):
                permute_lines_list.append(np.roll(pts_to_permute, shift_i, axis=0))
            flip_pts_to_permute = np.flip(pts_to_permute, axis=0)
            for shift_i in range(permute_num):
                permute_lines_list.append(np.roll(flip_pts_to_permute, shift_i, axis=0))
        else:
            # Open line: only forward and reversed orders are equivalent.
            permute_lines_list.append(line)
            permute_lines_list.append(np.flip(line, axis=0))

        permute_lines_array = np.stack(permute_lines_list, axis=0)

        if is_closed:
            tmp = np.zeros((permute_num * 2, num_points, self.coords_dim))
            tmp[:, :-1, :] = permute_lines_array
            tmp[:, -1, :] = permute_lines_array[:, 0, :]  # add replicate start end pts
            permute_lines_array = tmp

        else:
            # Pad open lines up to the closed-line permutation count.
            padding = np.full([permute_num * 2 - 2, num_points, self.coords_dim], padding)
            permute_lines_array = np.concatenate((permute_lines_array, padding), axis=0)

        return permute_lines_array

    def __call__(self, input_dict):
        if "map_geoms" not in input_dict:
            return input_dict
        map_geoms = input_dict['map_geoms']
        vectors = self.get_vectorized_lines(map_geoms)

        if self.permute:
            # Flatten per-label lists into parallel label / points arrays.
            gt_map_labels, gt_map_pts = [], []
            for label, vecs in vectors.items():
                for vec in vecs:
                    gt_map_labels.append(label)
                    gt_map_pts.append(vec)
            input_dict['gt_map_labels'] = np.array(gt_map_labels, dtype=np.int64)
            input_dict['gt_map_pts'] = np.array(gt_map_pts, dtype=np.float32).reshape(-1, 2 * (self.sample_num - 1), self.sample_num, self.coords_dim)
        else:
            input_dict['vectors'] = DC(vectors, stack=False, cpu_only=True)

        return input_dict

    def __repr__(self):
        # BUG FIX: the original string concatenation had stray closing
        # parens and missing "(", producing e.g. "sample_num=20), ".
        repr_str = self.__class__.__name__
        repr_str += f'(simplify={self.simplify}, '
        repr_str += f'sample_num={self.sample_num}, '
        repr_str += f'sample_dist={self.sample_dist}, '
        repr_str += f'roi_size={self.roi_size}, '
        repr_str += f'normalize={self.normalize}, '
        repr_str += f'coords_dim={self.coords_dim})'
        return repr_str
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/datasets/samplers/__init__.py b/projects/mmdet3d_plugin/datasets/samplers/__init__.py
new file mode 100644
index 0000000..17da039
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/samplers/__init__.py
@@ -0,0 +1,6 @@
+from .group_sampler import DistributedGroupSampler
+from .distributed_sampler import DistributedSampler
+from .sampler import SAMPLER, build_sampler
+from .group_in_batch_sampler import (
+ GroupInBatchSampler,
+)
diff --git a/projects/mmdet3d_plugin/datasets/samplers/distributed_sampler.py b/projects/mmdet3d_plugin/datasets/samplers/distributed_sampler.py
new file mode 100644
index 0000000..3cd9077
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/samplers/distributed_sampler.py
@@ -0,0 +1,82 @@
+import math
+
+import torch
+from torch.utils.data import DistributedSampler as _DistributedSampler
+from .sampler import SAMPLER
+
+import pdb
+import sys
+
+
class ForkedPdb(pdb.Pdb):
    """A pdb variant usable from forked (worker) processes.

    Re-opens /dev/stdin for the duration of the interactive session so a
    child process whose stdin is detached can still talk to the debugger.
    """

    def interaction(self, *args, **kwargs):
        saved_stdin = sys.stdin
        try:
            sys.stdin = open("/dev/stdin")
            pdb.Pdb.interaction(self, *args, **kwargs)
        finally:
            # Always restore the process's original stdin.
            sys.stdin = saved_stdin
+
+
def set_trace():
    """Drop into a ForkedPdb session at the caller's frame."""
    caller_frame = sys._getframe().f_back
    ForkedPdb().set_trace(caller_frame)
+
+
@SAMPLER.register_module()
class DistributedSampler(_DistributedSampler):
    """Distributed sampler that hands whole temporal sequences to each rank.

    The dataset is split into sequences wherever the timestamp jumps by
    more than 4 seconds or the recording vehicle changes; each rank then
    takes the sequences whose running prefix sum falls into its share, so
    per-rank data stays temporally ordered. Requires ``shuffle=False``.
    """

    def __init__(
        self, dataset=None, num_replicas=None, rank=None, shuffle=True, seed=0
    ):
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle
        )
        # for the compatibility from PyTorch 1.3+
        self.seed = seed if seed is not None else 0

    def __iter__(self):
        # deterministically shuffle based on epoch
        # Sequence-aware assignment only works without shuffling.
        assert not self.shuffle
        if "data_infos" in dir(self.dataset):
            # Plain dataset: timestamps (converted to seconds) and a
            # vehicle id parsed from the lidar file name (None when the
            # info has no lidar path).
            timestamps = [
                x["timestamp"] / 1e6 for x in self.dataset.data_infos
            ]
            vehicle_idx = [
                x["lidar_path"].split("/")[-1][:4]
                if "lidar_path" in x
                else None
                for x in self.dataset.data_infos
            ]
        else:
            # Concatenated datasets: repeats the first sub-dataset's infos
            # for every sub-dataset -- assumes all sub-datasets share the
            # same data_infos (TODO confirm).
            timestamps = [
                x["timestamp"] / 1e6
                for x in self.dataset.datasets[0].data_infos
            ] * len(self.dataset.datasets)
            vehicle_idx = [
                x["lidar_path"].split("/")[-1][:4]
                if "lidar_path" in x
                else None
                for x in self.dataset.datasets[0].data_infos
            ] * len(self.dataset.datasets)

        # Start a new sequence at every >4 s time gap or vehicle change.
        sequence_splits = []
        for i in range(len(timestamps)):
            if i == 0 or (
                abs(timestamps[i] - timestamps[i - 1]) > 4
                or vehicle_idx[i] != vehicle_idx[i - 1]
            ):
                sequence_splits.append([i])
            else:
                sequence_splits[-1].append(i)

        # Walk the sequences and keep those whose prefix sum lands inside
        # this rank's [rank * split_length, (rank + 1) * split_length) slice.
        indices = []
        perfix_sum = 0
        split_length = len(self.dataset) // self.num_replicas
        for i in range(len(sequence_splits)):
            if perfix_sum >= (self.rank + 1) * split_length:
                break
            elif perfix_sum >= self.rank * split_length:
                indices.extend(sequence_splits[i])
            perfix_sum += len(sequence_splits[i])

        # Per-rank sample count varies with where sequence boundaries fall.
        self.num_samples = len(indices)
        return iter(indices)
diff --git a/projects/mmdet3d_plugin/datasets/samplers/group_in_batch_sampler.py b/projects/mmdet3d_plugin/datasets/samplers/group_in_batch_sampler.py
new file mode 100644
index 0000000..62b5cb0
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/samplers/group_in_batch_sampler.py
@@ -0,0 +1,178 @@
+# https://github.com/Divadi/SOLOFusion/blob/main/mmdet3d/datasets/samplers/infinite_group_each_sample_in_batch_sampler.py
+import itertools
+import copy
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import get_dist_info
+from torch.utils.data.sampler import Sampler
+
+
+# https://github.com/open-mmlab/mmdetection/blob/3b72b12fe9b14de906d1363982b9fba05e7d47c1/mmdet/core/utils/dist_utils.py#L157
def sync_random_seed(seed=None, device="cuda"):
    """Share one random seed across every process in the distributed group.

    All workers must call this function, otherwise it will deadlock.
    Agreeing on a single seed lets every rank shuffle data indices in the
    same order and then select non-overlapping slices from the same list.

    Args:
        seed (int, Optional): The seed. Default to None (draw one randomly).
        device (str): The device where the broadcast tensor is put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()

    if world_size == 1:
        # Single process: nothing to synchronize.
        return seed

    # Rank 0 owns the seed; everyone else receives it via broadcast.
    if rank == 0:
        payload = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        payload = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(payload, src=0)
    return payload.item()
+
+
class GroupInBatchSampler(Sampler):
    """
    Pardon this horrendous name. Basically, we want every sample to be from its own group.
    If batch size is 4 and # of GPUs is 8, each sample of these 32 should be operating on
    its own group.

    Shuffling is only done for group order, not done within groups.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        world_size=None,
        rank=None,
        seed=0,
        skip_prob=0.,
        sequence_flip_prob=0.,
    ):
        # Fall back to the values reported by the distributed environment
        # when world_size / rank are not given explicitly.
        _rank, _world_size = get_dist_info()
        if world_size is None:
            world_size = _world_size
        if rank is None:
            rank = _rank

        self.dataset = dataset
        self.batch_size = batch_size
        self.world_size = world_size
        self.rank = rank
        # All ranks must agree on the seed so group order is identical.
        self.seed = sync_random_seed(seed)

        self.size = len(self.dataset)

        # `flag` assigns every dataset sample to a group (sequence) id.
        assert hasattr(self.dataset, "flag")
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)
        self.groups_num = len(self.group_sizes)
        self.global_batch_size = batch_size * world_size
        # Every global batch slot needs its own group to draw from.
        assert self.groups_num >= self.global_batch_size

        # Now, for efficiency, make a dict group_idx: List[dataset sample_idxs]
        self.group_idx_to_sample_idxs = {
            group_idx: np.where(self.flag == group_idx)[0].tolist()
            for group_idx in range(self.groups_num)
        }

        # Get a generator per sample idx. Considering samples over all
        # GPUs, each sample position has its own generator
        self.group_indices_per_global_sample_idx = [
            self._group_indices_per_global_sample_idx(
                self.rank * self.batch_size + local_sample_idx
            )
            for local_sample_idx in range(self.batch_size)
        ]

        # Keep track of a buffer of dataset sample idxs for each local sample idx
        self.buffer_per_local_sample = [[] for _ in range(self.batch_size)]
        # One augmentation config per batch slot, reused along a sequence
        # when the dataset asks for consistent sequence augmentation.
        self.aug_per_local_sample = [None for _ in range(self.batch_size)]
        self.skip_prob = skip_prob
        self.sequence_flip_prob = sequence_flip_prob

    def _infinite_group_indices(self):
        # Endless stream of group ids, reshuffled after each full pass.
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            yield from torch.randperm(self.groups_num, generator=g).tolist()

    def _group_indices_per_global_sample_idx(self, global_sample_idx):
        # Stride through the shared shuffled stream so each global batch
        # slot consumes a disjoint subsequence of groups.
        yield from itertools.islice(
            self._infinite_group_indices(),
            global_sample_idx,
            None,
            self.global_batch_size,
        )

    def __iter__(self):
        while True:
            curr_batch = []
            for local_sample_idx in range(self.batch_size):
                # Optionally skip one frame within the current sequence
                # (temporal augmentation); never skip past the sequence end.
                skip = (
                    np.random.uniform() < self.skip_prob
                    and len(self.buffer_per_local_sample[local_sample_idx]) > 1
                )
                if len(self.buffer_per_local_sample[local_sample_idx]) == 0:
                    # Finished current group, refill with next group
                    # skip = False
                    new_group_idx = next(
                        self.group_indices_per_global_sample_idx[
                            local_sample_idx
                        ]
                    )
                    self.buffer_per_local_sample[
                        local_sample_idx
                    ] = copy.deepcopy(
                        self.group_idx_to_sample_idxs[new_group_idx]
                    )
                    # Optionally play the new sequence backwards.
                    if np.random.uniform() < self.sequence_flip_prob:
                        self.buffer_per_local_sample[
                            local_sample_idx
                        ] = self.buffer_per_local_sample[local_sample_idx][
                            ::-1
                        ]
                    # Draw one augmentation for the whole sequence.
                    if self.dataset.keep_consistent_seq_aug:
                        self.aug_per_local_sample[
                            local_sample_idx
                        ] = self.dataset.get_augmentation()

                # Otherwise re-draw the augmentation for every frame.
                if not self.dataset.keep_consistent_seq_aug:
                    self.aug_per_local_sample[
                        local_sample_idx
                    ] = self.dataset.get_augmentation()

                if skip:
                    self.buffer_per_local_sample[local_sample_idx].pop(0)
                curr_batch.append(
                    dict(
                        idx=self.buffer_per_local_sample[local_sample_idx].pop(
                            0
                        ),
                        aug_config=self.aug_per_local_sample[local_sample_idx],
                    )
                )

            yield curr_batch

    def __len__(self):
        """Length of base dataset."""
        return self.size

    def set_epoch(self, epoch):
        # Kept for API compatibility; the infinite stream ignores epochs.
        self.epoch = epoch
diff --git a/projects/mmdet3d_plugin/datasets/samplers/group_sampler.py b/projects/mmdet3d_plugin/datasets/samplers/group_sampler.py
new file mode 100644
index 0000000..a14e114
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/samplers/group_sampler.py
@@ -0,0 +1,119 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+
+import numpy as np
+import torch
+from mmcv.runner import get_dist_info
+from torch.utils.data import Sampler
+from .sampler import SAMPLER
+import random
+from IPython import embed
+
+
@SAMPLER.register_module()
class DistributedGroupSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.
    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.
    .. note::
        Dataset is assumed to be of constant size.
    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
        seed (int, optional): random seed used to shuffle the sampler if
            ``shuffle=True``. This number should be identical across all
            processes in the distributed group. Default: 0.
    """

    def __init__(
        self, dataset, samples_per_gpu=1, num_replicas=None, rank=None, seed=0
    ):
        # Fall back to the distributed environment's rank / world size.
        _rank, _num_replicas = get_dist_info()
        if num_replicas is None:
            num_replicas = _num_replicas
        if rank is None:
            rank = _rank
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.seed = seed if seed is not None else 0

        # `flag` assigns every sample to a group; sampling keeps batches
        # homogeneous within a group.
        assert hasattr(self.dataset, "flag")
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)

        # Per-rank sample count: each group is padded up so it divides
        # evenly into (samples_per_gpu * num_replicas)-sized chunks.
        self.num_samples = 0
        for i, j in enumerate(self.group_sizes):
            self.num_samples += (
                int(
                    math.ceil(
                        self.group_sizes[i]
                        * 1.0
                        / self.samples_per_gpu
                        / self.num_replicas
                    )
                )
                * self.samples_per_gpu
            )
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch + self.seed)

        indices = []
        for i, size in enumerate(self.group_sizes):
            if size > 0:
                indice = np.where(self.flag == i)[0]
                assert len(indice) == size
                # add .numpy() to avoid bug when selecting indice in parrots.
                # TODO: check whether torch.randperm() can be replaced by
                # numpy.random.permutation().
                indice = indice[
                    list(torch.randperm(int(size), generator=g).numpy())
                ].tolist()
                # Pad the group by repeating its own (shuffled) indices so
                # its length is a multiple of samples_per_gpu * num_replicas.
                extra = int(
                    math.ceil(
                        size * 1.0 / self.samples_per_gpu / self.num_replicas
                    )
                ) * self.samples_per_gpu * self.num_replicas - len(indice)
                # pad indice
                tmp = indice.copy()
                for _ in range(extra // size):
                    indice.extend(tmp)
                indice.extend(tmp[: extra % size])
                indices.extend(indice)

        assert len(indices) == self.total_size

        # Shuffle at batch granularity (blocks of samples_per_gpu) so each
        # GPU batch still comes from a single group.
        indices = [
            indices[j]
            for i in list(
                torch.randperm(
                    len(indices) // self.samples_per_gpu, generator=g
                )
            )
            for j in range(
                i * self.samples_per_gpu, (i + 1) * self.samples_per_gpu
            )
        ]

        # subsample
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
diff --git a/projects/mmdet3d_plugin/datasets/samplers/sampler.py b/projects/mmdet3d_plugin/datasets/samplers/sampler.py
new file mode 100644
index 0000000..9bfc443
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/samplers/sampler.py
@@ -0,0 +1,7 @@
+from mmcv.utils.registry import Registry, build_from_cfg
+
+SAMPLER = Registry("sampler")
+
+
+def build_sampler(cfg, default_args):
+    """Instantiate a sampler from ``cfg`` via the SAMPLER registry."""
+    return build_from_cfg(cfg, SAMPLER, default_args)
diff --git a/projects/mmdet3d_plugin/datasets/utils.py b/projects/mmdet3d_plugin/datasets/utils.py
new file mode 100644
index 0000000..bf24b2f
--- /dev/null
+++ b/projects/mmdet3d_plugin/datasets/utils.py
@@ -0,0 +1,225 @@
+import copy
+
+import cv2
+import numpy as np
+import torch
+
+from projects.mmdet3d_plugin.core.box3d import *
+
+
+def box3d_to_corners(box3d):
+    """Convert boxes ``(N, >=7)`` laid out as [x, y, z, w, l, h, yaw, ...]
+    (indices taken from the box3d constants) into their 8 corner points,
+    returned as a numpy array of shape ``(N, 8, 3)`` in the same frame.
+    """
+    if isinstance(box3d, torch.Tensor):
+        box3d = box3d.detach().cpu().numpy()
+    # 8 corners of the unit cube; reorder so consecutive indices are edges
+    corners_norm = np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)
+    corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
+    # center the template: relative origin [0.5, 0.5, 0.5] (gravity center)
+    corners_norm = corners_norm - np.array([0.5, 0.5, 0.5])
+    corners = box3d[:, None, [W, L, H]] * corners_norm.reshape([1, 8, 3])
+
+    # rotate around z axis
+    rot_cos = np.cos(box3d[:, YAW])
+    rot_sin = np.sin(box3d[:, YAW])
+    rot_mat = np.tile(np.eye(3)[None], (box3d.shape[0], 1, 1))
+    rot_mat[:, 0, 0] = rot_cos
+    rot_mat[:, 0, 1] = -rot_sin
+    rot_mat[:, 1, 0] = rot_sin
+    rot_mat[:, 1, 1] = rot_cos
+    corners = (rot_mat[:, None] @ corners[..., None]).squeeze(axis=-1)
+    # translate to the box center
+    corners += box3d[:, None, :3]
+    return corners
+
+
+def plot_rect3d_on_img(
+    img, num_rects, rect_corners, color=(0, 255, 0), thickness=1
+):
+    """Plot the boundary lines of 3D rectangular on 2D images.
+
+    Args:
+        img (numpy.array): The numpy array of image.
+        num_rects (int): Number of 3D rectangulars.
+        rect_corners (numpy.array): Coordinates of the corners of 3D
+            rectangulars. Should be in the shape of [num_rect, 8, 2].
+        color (tuple[int], optional): The color to draw bboxes. May also be
+            a per-box sequence of colors. Default: (0, 255, 0).
+        thickness (int, optional): The thickness of bboxes. Default: 1.
+
+    Returns:
+        numpy.array: The image with boxes drawn, cast to uint8.
+    """
+    # pairs of corner indices forming the 12 box edges
+    line_indices = (
+        (0, 1),
+        (0, 3),
+        (0, 4),
+        (1, 2),
+        (1, 5),
+        (3, 2),
+        (3, 7),
+        (4, 5),
+        (4, 7),
+        (2, 6),
+        (5, 6),
+        (6, 7),
+    )
+    h, w = img.shape[:2]
+    for i in range(num_rects):
+        corners = np.clip(rect_corners[i], -1e4, 1e5).astype(np.int32)
+        for start, end in line_indices:
+            # skip edges whose BOTH endpoints fall outside the image
+            if (
+                (corners[start, 1] >= h or corners[start, 1] < 0)
+                or (corners[start, 0] >= w or corners[start, 0] < 0)
+            ) and (
+                (corners[end, 1] >= h or corners[end, 1] < 0)
+                or (corners[end, 0] >= w or corners[end, 0] < 0)
+            ):
+                continue
+            # single shared color vs one color per box
+            if isinstance(color[0], int):
+                cv2.line(
+                    img,
+                    (corners[start, 0], corners[start, 1]),
+                    (corners[end, 0], corners[end, 1]),
+                    color,
+                    thickness,
+                    cv2.LINE_AA,
+                )
+            else:
+                cv2.line(
+                    img,
+                    (corners[start, 0], corners[start, 1]),
+                    (corners[end, 0], corners[end, 1]),
+                    color[i],
+                    thickness,
+                    cv2.LINE_AA,
+                )
+
+    return img.astype(np.uint8)
+
+
+def draw_lidar_bbox3d_on_img(
+    bboxes3d, raw_img, lidar2img_rt, img_metas=None, color=(0, 255, 0), thickness=1
+):
+    """Project the 3D bbox on 2D plane and draw on input image.
+
+    Args:
+        bboxes3d (:obj:`LiDARInstance3DBoxes`):
+            3d bbox in lidar coordinate system to visualize.
+        raw_img (numpy.array): The numpy array of image.
+        lidar2img_rt (numpy.array, shape=[4, 4]): The projection matrix
+            according to the camera intrinsic parameters.
+        img_metas (dict): Useless here.
+        color (tuple[int], optional): The color to draw bboxes.
+            Default: (0, 255, 0).
+        thickness (int, optional): The thickness of bboxes. Default: 1.
+    """
+    img = raw_img.copy()
+    # corners_3d = bboxes3d.corners
+    corners_3d = box3d_to_corners(bboxes3d)
+    num_bbox = corners_3d.shape[0]
+    # homogeneous corner coordinates, shape (num_bbox * 8, 4)
+    pts_4d = np.concatenate(
+        [corners_3d.reshape(-1, 3), np.ones((num_bbox * 8, 1))], axis=-1
+    )
+    lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4)
+    if isinstance(lidar2img_rt, torch.Tensor):
+        lidar2img_rt = lidar2img_rt.cpu().numpy()
+    pts_2d = pts_4d @ lidar2img_rt.T
+
+    # clamp depth away from zero before the perspective divide
+    pts_2d[:, 2] = np.clip(pts_2d[:, 2], a_min=1e-5, a_max=1e5)
+    pts_2d[:, 0] /= pts_2d[:, 2]
+    pts_2d[:, 1] /= pts_2d[:, 2]
+    imgfov_pts_2d = pts_2d[..., :2].reshape(num_bbox, 8, 2)
+
+    return plot_rect3d_on_img(img, num_bbox, imgfov_pts_2d, color, thickness)
+
+
+def draw_points_on_img(points, img, lidar2img_rt, color=(0, 255, 0), circle=4):
+    """Project 3D points onto ``img`` and draw them as filled circles.
+
+    ``points`` is a torch.Tensor of shape (N, P, 3); ``color`` is one BGR
+    tuple shared by all groups or one tuple per group of points.
+    """
+    img = img.copy()
+    N = points.shape[0]
+    points = points.cpu().numpy()
+    lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4)
+    if isinstance(lidar2img_rt, torch.Tensor):
+        lidar2img_rt = lidar2img_rt.cpu().numpy()
+    # rotate + translate: equivalent to (R @ p + t) for each point
+    pts_2d = (
+        np.sum(points[:, :, None] * lidar2img_rt[:3, :3], axis=-1)
+        + lidar2img_rt[:3, 3]
+    )
+    # clamp depth away from zero, then perspective divide
+    pts_2d[..., 2] = np.clip(pts_2d[..., 2], a_min=1e-5, a_max=1e5)
+    pts_2d = pts_2d[..., :2] / pts_2d[..., 2:3]
+    pts_2d = np.clip(pts_2d, -1e4, 1e4).astype(np.int32)
+
+    for i in range(N):
+        for point in pts_2d[i]:
+            if isinstance(color[0], int):
+                color_tmp = color
+            else:
+                color_tmp = color[i]
+            cv2.circle(img, point.tolist(), circle, color_tmp, thickness=-1)
+    return img.astype(np.uint8)
+
+
+def draw_lidar_bbox3d_on_bev(
+ bboxes_3d, bev_size, bev_range=115, color=(255, 0, 0), thickness=3):
+ if isinstance(bev_size, (list, tuple)):
+ bev_h, bev_w = bev_size
+ else:
+ bev_h, bev_w = bev_size, bev_size
+ bev = np.zeros([bev_h, bev_w, 3])
+
+ marking_color = (127, 127, 127)
+ bev_resolution = bev_range / bev_h
+ for cir in range(int(bev_range / 2 / 10)):
+ cv2.circle(
+ bev,
+ (int(bev_h / 2), int(bev_w / 2)),
+ int((cir + 1) * 10 / bev_resolution),
+ marking_color,
+ thickness=thickness,
+ )
+ cv2.line(
+ bev,
+ (0, int(bev_h / 2)),
+ (bev_w, int(bev_h / 2)),
+ marking_color,
+ )
+ cv2.line(
+ bev,
+ (int(bev_w / 2), 0),
+ (int(bev_w / 2), bev_h),
+ marking_color,
+ )
+ if len(bboxes_3d) != 0:
+ bev_corners = box3d_to_corners(bboxes_3d)[:, [0, 3, 4, 7]][
+ ..., [0, 1]
+ ]
+ xs = bev_corners[..., 0] / bev_resolution + bev_w / 2
+ ys = -bev_corners[..., 1] / bev_resolution + bev_h / 2
+ for obj_idx, (x, y) in enumerate(zip(xs, ys)):
+ for p1, p2 in ((0, 1), (0, 2), (1, 3), (2, 3)):
+ if isinstance(color[0], (list, tuple)):
+ tmp = color[obj_idx]
+ else:
+ tmp = color
+ cv2.line(
+ bev,
+ (int(x[p1]), int(y[p1])),
+ (int(x[p2]), int(y[p2])),
+ tmp,
+ thickness=thickness,
+ )
+ return bev.astype(np.uint8)
+
+
+def draw_lidar_bbox3d(bboxes_3d, imgs, lidar2imgs, color=(255, 0, 0)):
+    """Draw boxes on every camera image, tile the views into a grid, and
+    prepend a BEV rendering on the left. Returns one composite uint8 image.
+    """
+    vis_imgs = []
+    for i, (img, lidar2img) in enumerate(zip(imgs, lidar2imgs)):
+        vis_imgs.append(
+            draw_lidar_bbox3d_on_img(bboxes_3d, img, lidar2img, color=color)
+        )
+
+    # even counts >= 4 are tiled as two rows, otherwise a single row
+    num_imgs = len(vis_imgs)
+    if num_imgs < 4 or num_imgs % 2 != 0:
+        vis_imgs = np.concatenate(vis_imgs, axis=1)
+    else:
+        vis_imgs = np.concatenate([
+            np.concatenate(vis_imgs[:num_imgs//2], axis=1),
+            np.concatenate(vis_imgs[num_imgs//2:], axis=1)
+        ], axis=0)
+
+    # square BEV whose side matches the tiled image height
+    bev = draw_lidar_bbox3d_on_bev(bboxes_3d, vis_imgs.shape[0], color=color)
+    vis_imgs = np.concatenate([bev, vis_imgs], axis=1)
+    return vis_imgs
diff --git a/projects/mmdet3d_plugin/models/__init__.py b/projects/mmdet3d_plugin/models/__init__.py
new file mode 100644
index 0000000..eb86c67
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/__init__.py
@@ -0,0 +1,34 @@
+from .sparsedrive import SparseDrive
+from .sparsedrive_head import SparseDriveHead
+from .gt_sparse_drive_head import GTSparseDriveHead
+from .blocks import (
+ DeformableFeatureAggregation,
+ DenseDepthNet,
+ AsymmetricFFN,
+)
+from .instance_bank import InstanceBank
+from .detection3d import (
+ SparseBox3DDecoder,
+ SparseBox3DTarget,
+ SparseBox3DRefinementModule,
+ SparseBox3DKeyPointsGenerator,
+ SparseBox3DEncoder,
+)
+from .map import *
+from .motion import *
+
+
+__all__ = [
+ "SparseDrive",
+ "SparseDriveHead",
+ "GTSparseDriveHead",
+ "DeformableFeatureAggregation",
+ "DenseDepthNet",
+ "AsymmetricFFN",
+ "InstanceBank",
+ "SparseBox3DDecoder",
+ "SparseBox3DTarget",
+ "SparseBox3DRefinementModule",
+ "SparseBox3DKeyPointsGenerator",
+ "SparseBox3DEncoder",
+]
diff --git a/projects/mmdet3d_plugin/models/attention.py b/projects/mmdet3d_plugin/models/attention.py
new file mode 100644
index 0000000..9afe1ec
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/attention.py
@@ -0,0 +1,319 @@
+import warnings
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn.functional import linear
+from torch.nn.init import xavier_uniform_, constant_
+
+from mmcv.utils import deprecated_api_warning
+from mmcv.runner import auto_fp16
+from mmcv.runner.base_module import BaseModule
+from mmcv.cnn.bricks.drop import build_dropout
+from mmcv.cnn.bricks.registry import ATTENTION
+import torch.utils.checkpoint as cp
+
+
+from einops import rearrange
+try:
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+ print('Use flash_attn_unpadded_kvpacked_func')
+except:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func
+ print('Use flash_attn_varlen_kvpacked_func')
+from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
+
+
+def _in_projection_packed(q, k, v, w, b = None):
+    """Split a packed (3*E, E) in-projection weight ``w`` (and optional bias
+    ``b``) into q/k/v thirds and apply each as a linear projection."""
+    w_q, w_k, w_v = w.chunk(3)
+    if b is None:
+        b_q = b_k = b_v = None
+    else:
+        b_q, b_k, b_v = b.chunk(3)
+    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
+
+
+class FlashAttention(nn.Module):
+    """Implement the scaled dot product attention with softmax.
+    Arguments
+    ---------
+        softmax_scale: The temperature to use for the softmax attention.
+                      (default: 1/sqrt(d_keys) where d_keys is computed at
+                      runtime)
+        attention_dropout: The dropout rate to apply to the attention
+                           (default: 0.1)
+    """
+    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
+        super().__init__()
+        self.softmax_scale = softmax_scale
+        self.dropout_p = attention_dropout
+        # mark for mmcv's auto_fp16 decorator below
+        self.fp16_enabled = True
+
+    @auto_fp16(apply_to=('q', 'kv'), out_fp32=True)
+    def forward(self, q, kv,
+                causal=False,
+                key_padding_mask=None):
+        """Implements the multihead softmax attention.
+        Arguments
+        ---------
+            q: The tensor containing the query. (B, T, H, D)
+            kv: The tensor containing the key, and value. (B, S, 2, H, D)
+            key_padding_mask: a bool tensor of shape (B, S)
+        """
+        # flash kernels only accept half precision on CUDA
+        assert q.dtype in [torch.float16, torch.bfloat16] and kv.dtype in [torch.float16, torch.bfloat16]
+        assert q.is_cuda and kv.is_cuda
+        assert q.shape[0] == kv.shape[0] and q.shape[-2] == kv.shape[-2] and q.shape[-1] == kv.shape[-1]
+
+        batch_size = q.shape[0]
+        seqlen_q, seqlen_k = q.shape[1], kv.shape[1]
+        if key_padding_mask is None:
+            # no padding: flatten batch/seq and pass uniform cumulative lengths
+            q, kv = rearrange(q, 'b s ... -> (b s) ...'), rearrange(kv, 'b s ... -> (b s) ...')
+            max_sq, max_sk = seqlen_q, seqlen_k
+            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
+                                        device=q.device)
+            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
+                                        device=kv.device)
+            output = flash_attn_unpadded_kvpacked_func(
+                q, kv, cu_seqlens_q, cu_seqlens_k, max_sq, max_sk,
+                self.dropout_p if self.training else 0.0,
+                softmax_scale=self.softmax_scale, causal=causal
+            )
+            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+        else:
+            # padded keys: drop masked positions from kv before the kernel,
+            # then scatter the result back into the padded layout
+            nheads = kv.shape[-2]
+            q = rearrange(q, 'b s ... -> (b s) ...')
+            max_sq = seqlen_q
+            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
+                                        device=q.device)
+            x = rearrange(kv, 'b s two h d -> b s (two h d)')
+            x_unpad, indices, cu_seqlens_k, max_sk = unpad_input(x, key_padding_mask)
+            x_unpad = rearrange(x_unpad, 'nnz (two h d) -> nnz two h d', two=2, h=nheads)
+            output_unpad = flash_attn_unpadded_kvpacked_func(
+                q, x_unpad, cu_seqlens_q, cu_seqlens_k, max_sq, max_sk,
+                self.dropout_p if self.training else 0.0,
+                softmax_scale=self.softmax_scale, causal=causal
+            )
+            output = rearrange(output_unpad, '(b s) ... -> b s ...', b=batch_size)
+
+        return output, None
+
+
+class FlashMHA(nn.Module):
+
+ def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
+ causal=False, device=None, dtype=None, **kwargs) -> None:
+ assert batch_first
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.causal = causal
+ self.bias = bias
+
+ self.num_heads = num_heads
+ assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
+ self.head_dim = self.embed_dim // num_heads
+ assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
+
+ self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim)))
+ if bias:
+ self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
+ else:
+ self.register_parameter('in_proj_bias', None)
+ self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self._reset_parameters()
+
+ def _reset_parameters(self) -> None:
+ xavier_uniform_(self.in_proj_weight)
+ if self.in_proj_bias is not None:
+ constant_(self.in_proj_bias, 0.)
+ constant_(self.out_proj.bias, 0.)
+
+ def forward(self, q, k, v, key_padding_mask=None):
+ """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
+ key_padding_mask: bool tensor of shape (batch, seqlen)
+ """
+ q, k, v = _in_projection_packed(q, k, v, self.in_proj_weight, self.in_proj_bias)
+ q = rearrange(q, 'b s (h d) -> b s h d', h=self.num_heads)
+ k = rearrange(k, 'b s (h d) -> b s h d', h=self.num_heads)
+ v = rearrange(v, 'b s (h d) -> b s h d', h=self.num_heads)
+ kv = torch.stack([k, v], dim=2)
+
+ context, attn_weights = self.inner_attn(q, kv, key_padding_mask=key_padding_mask, causal=self.causal)
+ return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
+
+
+@ATTENTION.register_module()
+class MultiheadFlashAttention(BaseModule):
+    """A wrapper for ``torch.nn.MultiheadAttention``.
+    This module implements MultiheadAttention with identity connection,
+    and positional encoding is also passed as input.
+    Args:
+        embed_dims (int): The embedding dimension.
+        num_heads (int): Parallel attention heads.
+        attn_drop (float): A Dropout layer on attn_output_weights.
+            Default: 0.0.
+        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
+            Default: 0.0.
+        dropout_layer (obj:`ConfigDict`): The dropout_layer used
+            when adding the shortcut.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+        batch_first (bool): When it is True, Key, Query and Value are shape of
+            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
+            Default to False.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 dropout_layer=dict(type='Dropout', drop_prob=0.),
+                 init_cfg=None,
+                 batch_first=True,
+                 **kwargs):
+        super(MultiheadFlashAttention, self).__init__(init_cfg)
+        if 'dropout' in kwargs:
+            warnings.warn(
+                'The arguments `dropout` in MultiheadAttention '
+                'has been deprecated, now you can separately '
+                'set `attn_drop`(float), proj_drop(float), '
+                'and `dropout_layer`(dict) ', DeprecationWarning)
+            attn_drop = kwargs['dropout']
+            dropout_layer['drop_prob'] = kwargs.pop('dropout')
+
+        self.embed_dims = embed_dims
+        self.num_heads = num_heads
+        # FlashMHA only supports batch-first tensors
+        self.batch_first = True
+        self.attn = FlashMHA(
+            embed_dim=embed_dims,
+            num_heads=num_heads,
+            attention_dropout=attn_drop,
+            dtype=torch.float16,
+            device='cuda',
+            **kwargs
+        )
+
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.dropout_layer = build_dropout(
+            dropout_layer) if dropout_layer else nn.Identity()
+
+    @deprecated_api_warning({'residual': 'identity'},
+                            cls_name='MultiheadAttention')
+    def forward(self,
+                query,
+                key=None,
+                value=None,
+                identity=None,
+                query_pos=None,
+                key_pos=None,
+                attn_mask=None,
+                key_padding_mask=None,
+                **kwargs):
+        """Forward function for `MultiheadAttention`.
+        **kwargs allow passing a more general data flow when combining
+        with other operations in `transformerlayer`.
+        Args:
+            query (Tensor): The input query with shape [num_queries, bs,
+                embed_dims] if self.batch_first is False, else
+                [bs, num_queries embed_dims].
+            key (Tensor): The key tensor with shape [num_keys, bs,
+                embed_dims] if self.batch_first is False, else
+                [bs, num_keys, embed_dims] .
+                If None, the ``query`` will be used. Defaults to None.
+            value (Tensor): The value tensor with same shape as `key`.
+                Same in `nn.MultiheadAttention.forward`. Defaults to None.
+                If None, the `key` will be used.
+            identity (Tensor): This tensor, with the same shape as x,
+                will be used for the identity link.
+                If None, `x` will be used. Defaults to None.
+            query_pos (Tensor): The positional encoding for query, with
+                the same shape as `x`. If not None, it will
+                be added to `x` before forward function. Defaults to None.
+            key_pos (Tensor): The positional encoding for `key`, with the
+                same shape as `key`. Defaults to None. If not None, it will
+                be added to `key` before forward function. If None, and
+                `query_pos` has the same shape as `key`, then `query_pos`
+                will be used for `key_pos`. Defaults to None.
+            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
+                num_keys]. Same in `nn.MultiheadAttention.forward`.
+                Defaults to None.
+            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
+                Defaults to None.
+        Returns:
+            Tensor: forwarded results with shape
+                [num_queries, bs, embed_dims]
+                if self.batch_first is False, else
+                [bs, num_queries embed_dims].
+        """
+        if key is None:
+            key = query
+        if value is None:
+            value = key
+        if identity is None:
+            identity = query
+        if key_pos is None:
+            if query_pos is not None:
+                # use query_pos if key_pos is not available
+                if query_pos.shape == key.shape:
+                    key_pos = query_pos
+                else:
+                    warnings.warn(f'position encoding of key is'
+                                  f'missing in {self.__class__.__name__}.')
+        if query_pos is not None:
+            query = query + query_pos
+        if key_pos is not None:
+            key = key + key_pos
+
+        if attn_mask is not None:
+            # FlashAttention does not support arbitrary attn_mask (e.g. DN training masks).
+            # Fall back to standard attention using the same projection weights.
+            import torch.nn.functional as F
+            out, _ = F.multi_head_attention_forward(
+                query.transpose(0, 1), key.transpose(0, 1), value.transpose(0, 1),
+                self.attn.embed_dim, self.attn.num_heads,
+                self.attn.in_proj_weight, self.attn.in_proj_bias,
+                None, None, False,
+                self.attn.inner_attn.dropout_p if self.training else 0.0,
+                self.attn.out_proj.weight, self.attn.out_proj.bias,
+                training=self.training,
+                key_padding_mask=key_padding_mask,
+                attn_mask=attn_mask,
+            )
+            # fallback path computed in (seq, batch, dim); transpose back
+            return identity + self.dropout_layer(self.proj_drop(out.transpose(0, 1)))
+
+        # The dataflow('key', 'query', 'value') of ``FlashAttention`` is (batch, num_query, embed_dims).
+        if not self.batch_first:
+            query = query.transpose(0, 1)
+            key = key.transpose(0, 1)
+            value = value.transpose(0, 1)
+
+        out = self.attn(
+            q=query,
+            k=key,
+            v=value,
+            key_padding_mask=key_padding_mask)[0]
+
+        if not self.batch_first:
+            out = out.transpose(0, 1)
+
+        return identity + self.dropout_layer(self.proj_drop(out))
+
+
+def gen_sineembed_for_position(pos_tensor, hidden_dim=256):
+    """Mostly copy-paste from https://github.com/IDEA-opensource/DAB-DETR/
+
+    Encode 2D positions ``pos_tensor[..., :2]`` (values in [0, 1]) into
+    ``hidden_dim``-dimensional sinusoidal embeddings: half the channels
+    encode y, half encode x.
+    """
+    half_hidden_dim = hidden_dim // 2
+    scale = 2 * math.pi
+    dim_t = torch.arange(half_hidden_dim, dtype=torch.float32, device=pos_tensor.device)
+    # classic transformer frequency schedule
+    dim_t = 10000 ** (2 * (dim_t // 2) / half_hidden_dim)
+    x_embed = pos_tensor[..., 0] * scale
+    y_embed = pos_tensor[..., 1] * scale
+    pos_x = x_embed[..., None] / dim_t
+    pos_y = y_embed[..., None] / dim_t
+    # interleave sin/cos over the frequency axis
+    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
+    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
+    pos = torch.cat((pos_y, pos_x), dim=-1)
+    return pos
+
diff --git a/projects/mmdet3d_plugin/models/base_target.py b/projects/mmdet3d_plugin/models/base_target.py
new file mode 100644
index 0000000..020a88d
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/base_target.py
@@ -0,0 +1,49 @@
+from abc import ABC, abstractmethod
+
+
+__all__ = ["BaseTargetWithDenoising"]
+
+
+class BaseTargetWithDenoising(ABC):
+    """Abstract target assigner with optional denoising (DN) group support.
+
+    Subclasses must implement ``sample``; the DN hooks below are no-ops by
+    default so non-denoising assigners can ignore them.
+    """
+
+    def __init__(self, num_dn_groups=0, num_temp_dn_groups=0):
+        super(BaseTargetWithDenoising, self).__init__()
+        # number of denoising groups generated per frame
+        self.num_dn_groups = num_dn_groups
+        # number of DN groups carried over temporally (< 0 disables caching)
+        self.num_temp_dn_groups = num_temp_dn_groups
+        # cache populated by cache_dn(), consumed by update_dn()
+        self.dn_metas = None
+
+    @abstractmethod
+    def sample(self, cls_pred, box_pred, cls_target, box_target):
+        """
+        Perform Hungarian matching between predictions and ground truth,
+        returning the matched ground truth corresponding to the predictions
+        along with the corresponding regression weights.
+        """
+
+    def get_dn_anchors(self, cls_target, box_target, *args, **kwargs):
+        """
+        Generate noisy instances for the current frame, with a total of
+        'self.num_dn_groups' groups.
+        """
+        return None
+
+    def update_dn(self, instance_feature, anchor, *args, **kwargs):
+        """
+        Insert the previously saved 'self.dn_metas' into the noisy instances
+        of the current frame.
+        """
+
+    def cache_dn(
+        self,
+        dn_instance_feature,
+        dn_anchor,
+        dn_cls_target,
+        valid_mask,
+        dn_id_target,
+    ):
+        """
+        Randomly save information for 'self.num_temp_dn_groups' groups of
+        temporal noisy instances to 'self.dn_metas'.
+        """
+        if self.num_temp_dn_groups < 0:
+            return
+        self.dn_metas = dict(dn_anchor=dn_anchor[:, : self.num_temp_dn_groups])
diff --git a/projects/mmdet3d_plugin/models/blocks.py b/projects/mmdet3d_plugin/models/blocks.py
new file mode 100644
index 0000000..32cacdc
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/blocks.py
@@ -0,0 +1,393 @@
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.cuda.amp.autocast_mode import autocast
+
+from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
+from mmcv.runner.base_module import Sequential, BaseModule
+from mmcv.cnn.bricks.transformer import FFN
+from mmcv.utils import build_from_cfg
+from mmcv.cnn.bricks.drop import build_dropout
+from mmcv.cnn import xavier_init, constant_init
+from mmcv.cnn.bricks.registry import (
+ ATTENTION,
+ PLUGIN_LAYERS,
+ FEEDFORWARD_NETWORK,
+)
+
+try:
+ from ..ops import deformable_aggregation_function as DAF
+except:
+ DAF = None
+
+__all__ = [
+ "DeformableFeatureAggregation",
+ "DenseDepthNet",
+ "AsymmetricFFN",
+]
+
+
+def linear_relu_ln(embed_dims, in_loops, out_loops, input_dims=None):
+    """Build a list of layers: ``out_loops`` repetitions of
+    (``in_loops`` x [Linear, ReLU]) followed by a LayerNorm each.
+    The first Linear maps ``input_dims`` (default ``embed_dims``) up to
+    ``embed_dims``; all later ones are ``embed_dims -> embed_dims``.
+    """
+    if input_dims is None:
+        input_dims = embed_dims
+    layers = []
+    for _ in range(out_loops):
+        for _ in range(in_loops):
+            layers.append(Linear(input_dims, embed_dims))
+            layers.append(nn.ReLU(inplace=True))
+            input_dims = embed_dims
+        layers.append(nn.LayerNorm(embed_dims))
+    return layers
+
+
+@ATTENTION.register_module()
+class DeformableFeatureAggregation(BaseModule):
+    """Aggregate multi-view, multi-level image features at 3D key points.
+
+    For each anchor, key points generated by ``kps_generator`` are projected
+    into every camera / FPN level, features are sampled there, and fused with
+    predicted per-group weights. Optionally uses the custom fused
+    ``deformable_aggregation`` op (DAF) instead of grid_sample.
+    """
+
+    def __init__(
+        self,
+        embed_dims: int = 256,
+        num_groups: int = 8,
+        num_levels: int = 4,
+        num_cams: int = 6,
+        proj_drop: float = 0.0,
+        attn_drop: float = 0.0,
+        kps_generator: dict = None,
+        temporal_fusion_module=None,
+        use_temporal_anchor_embed=True,
+        use_deformable_func=False,
+        use_camera_embed=False,
+        residual_mode="add",
+    ):
+        super(DeformableFeatureAggregation, self).__init__()
+        if embed_dims % num_groups != 0:
+            raise ValueError(
+                f"embed_dims must be divisible by num_groups, "
+                f"but got {embed_dims} and {num_groups}"
+            )
+        self.group_dims = int(embed_dims / num_groups)
+        self.embed_dims = embed_dims
+        self.num_levels = num_levels
+        self.num_groups = num_groups
+        self.num_cams = num_cams
+        self.use_temporal_anchor_embed = use_temporal_anchor_embed
+        if use_deformable_func:
+            assert DAF is not None, "deformable_aggregation needs to be set up."
+        self.use_deformable_func = use_deformable_func
+        self.attn_drop = attn_drop
+        self.residual_mode = residual_mode
+        self.proj_drop = nn.Dropout(proj_drop)
+        kps_generator["embed_dims"] = embed_dims
+        self.kps_generator = build_from_cfg(kps_generator, PLUGIN_LAYERS)
+        self.num_pts = self.kps_generator.num_pts
+        if temporal_fusion_module is not None:
+            if "embed_dims" not in temporal_fusion_module:
+                temporal_fusion_module["embed_dims"] = embed_dims
+            self.temp_module = build_from_cfg(
+                temporal_fusion_module, PLUGIN_LAYERS
+            )
+        else:
+            self.temp_module = None
+        self.output_proj = Linear(embed_dims, embed_dims)
+
+        if use_camera_embed:
+            # encode the 3x4 projection matrix (12 values) per camera so the
+            # weight head can be shared across cameras
+            self.camera_encoder = Sequential(
+                *linear_relu_ln(embed_dims, 1, 2, 12)
+            )
+            self.weights_fc = Linear(
+                embed_dims, num_groups * num_levels * self.num_pts
+            )
+        else:
+            self.camera_encoder = None
+            self.weights_fc = Linear(
+                embed_dims, num_groups * num_cams * num_levels * self.num_pts
+            )
+
+    def init_weight(self):
+        # zero-init weights head so initial sampling weights are uniform
+        constant_init(self.weights_fc, val=0.0, bias=0.0)
+        xavier_init(self.output_proj, distribution="uniform", bias=0.0)
+
+    def forward(
+        self,
+        instance_feature: torch.Tensor,
+        anchor: torch.Tensor,
+        anchor_embed: torch.Tensor,
+        feature_maps: List[torch.Tensor],
+        metas: dict,
+        **kwargs: dict,
+    ):
+        bs, num_anchor = instance_feature.shape[:2]
+        key_points = self.kps_generator(anchor, instance_feature)
+        weights = self._get_weights(instance_feature, anchor_embed, metas)
+
+        if self.use_deformable_func:
+            # fused CUDA path: reshape points/weights to the op's layout
+            points_2d = (
+                self.project_points(
+                    key_points,
+                    metas["projection_mat"],
+                    metas.get("image_wh"),
+                )
+                .permute(0, 2, 3, 1, 4)
+                .reshape(bs, num_anchor, self.num_pts, self.num_cams, 2)
+            )
+            weights = (
+                weights.permute(0, 1, 4, 2, 3, 5)
+                .contiguous()
+                .reshape(
+                    bs,
+                    num_anchor,
+                    self.num_pts,
+                    self.num_cams,
+                    self.num_levels,
+                    self.num_groups,
+                )
+            )
+            features = DAF(*feature_maps, points_2d, weights).reshape(
+                bs, num_anchor, self.embed_dims
+            )
+        else:
+            # fallback path via grid_sample
+            features = self.feature_sampling(
+                feature_maps,
+                key_points,
+                metas["projection_mat"],
+                metas.get("image_wh"),
+            )
+            features = self.multi_view_level_fusion(features, weights)
+            features = features.sum(dim=2)  # fuse multi-point features
+        output = self.proj_drop(self.output_proj(features))
+        if self.residual_mode == "add":
+            output = output + instance_feature
+        elif self.residual_mode == "cat":
+            output = torch.cat([output, instance_feature], dim=-1)
+        return output
+
+    def _get_weights(self, instance_feature, anchor_embed, metas=None):
+        """Predict per-(cam, level, point) softmax weights for each group."""
+        bs, num_anchor = instance_feature.shape[:2]
+        feature = instance_feature + anchor_embed
+        if self.camera_encoder is not None:
+            camera_embed = self.camera_encoder(
+                metas["projection_mat"][:, :, :3].reshape(
+                    bs, self.num_cams, -1
+                )
+            )
+            feature = feature[:, :, None] + camera_embed[:, None]
+
+        weights = (
+            self.weights_fc(feature)
+            .reshape(bs, num_anchor, -1, self.num_groups)
+            .softmax(dim=-2)
+            .reshape(
+                bs,
+                num_anchor,
+                self.num_cams,
+                self.num_levels,
+                self.num_pts,
+                self.num_groups,
+            )
+        )
+        if self.training and self.attn_drop > 0:
+            # drop whole (camera, point) slots and rescale, dropout-style
+            mask = torch.rand(
+                bs, num_anchor, self.num_cams, 1, self.num_pts, 1
+            )
+            mask = mask.to(device=weights.device, dtype=weights.dtype)
+            weights = ((mask > self.attn_drop) * weights) / (
+                1 - self.attn_drop
+            )
+        return weights
+
+    @staticmethod
+    def project_points(key_points, projection_mat, image_wh=None):
+        """Project 3D key points into each camera; returns normalized 2D
+        coordinates when ``image_wh`` is given."""
+        bs, num_anchor, num_pts = key_points.shape[:3]
+
+        pts_extend = torch.cat(
+            [key_points, torch.ones_like(key_points[..., :1])], dim=-1
+        )
+        points_2d = torch.matmul(
+            projection_mat[:, :, None, None], pts_extend[:, None, ..., None]
+        ).squeeze(-1)
+        # perspective divide with clamped depth
+        points_2d = points_2d[..., :2] / torch.clamp(
+            points_2d[..., 2:3], min=1e-5
+        )
+        if image_wh is not None:
+            points_2d = points_2d / image_wh[:, :, None, None]
+        return points_2d
+
+    @staticmethod
+    def feature_sampling(
+        feature_maps: List[torch.Tensor],
+        key_points: torch.Tensor,
+        projection_mat: torch.Tensor,
+        image_wh: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        num_levels = len(feature_maps)
+        num_cams = feature_maps[0].shape[1]
+        bs, num_anchor, num_pts = key_points.shape[:3]
+
+        points_2d = DeformableFeatureAggregation.project_points(
+            key_points, projection_mat, image_wh
+        )
+        # [0, 1] -> [-1, 1] as required by grid_sample
+        points_2d = points_2d * 2 - 1
+        points_2d = points_2d.flatten(end_dim=1)
+
+        features = []
+        for fm in feature_maps:
+            features.append(
+                torch.nn.functional.grid_sample(
+                    fm.flatten(end_dim=1), points_2d
+                )
+            )
+        features = torch.stack(features, dim=1)
+        features = features.reshape(
+            bs, num_cams, num_levels, -1, num_anchor, num_pts
+        ).permute(
+            0, 4, 1, 2, 5, 3
+        )  # bs, num_anchor, num_cams, num_levels, num_pts, embed_dims
+
+        return features
+
+    def multi_view_level_fusion(
+        self,
+        features: torch.Tensor,
+        weights: torch.Tensor,
+    ):
+        """Weighted sum of sampled features over cameras and levels."""
+        bs, num_anchor = weights.shape[:2]
+        features = weights[..., None] * features.reshape(
+            features.shape[:-1] + (self.num_groups, self.group_dims)
+        )
+        # sum over cameras, then over levels
+        features = features.sum(dim=2).sum(dim=2)
+        features = features.reshape(
+            bs, num_anchor, self.num_pts, self.embed_dims
+        )
+        return features
+
+
+@PLUGIN_LAYERS.register_module()
+class DenseDepthNet(BaseModule):
+    """Per-pixel depth heads on FPN features with an auxiliary L1 loss.
+
+    Each head is a 1x1 conv whose exp() output is scaled by the camera focal
+    length relative to ``equal_focal``. In training with ``gt_depths`` given,
+    forward() returns the loss instead of the predictions.
+    """
+
+    def __init__(
+        self,
+        embed_dims=256,
+        num_depth_layers=1,
+        equal_focal=100,
+        max_depth=60,
+        loss_weight=1.0,
+    ):
+        super().__init__()
+        self.embed_dims = embed_dims
+        self.equal_focal = equal_focal
+        self.num_depth_layers = num_depth_layers
+        self.max_depth = max_depth
+        self.loss_weight = loss_weight
+
+        self.depth_layers = nn.ModuleList()
+        for i in range(num_depth_layers):
+            self.depth_layers.append(
+                nn.Conv2d(embed_dims, 1, kernel_size=1, stride=1, padding=0)
+            )
+
+    def forward(self, feature_maps, focal=None, gt_depths=None):
+        if focal is None:
+            focal = self.equal_focal
+        else:
+            focal = focal.reshape(-1)
+        depths = []
+        for i, feat in enumerate(feature_maps[: self.num_depth_layers]):
+            # exp() keeps predicted depth positive
+            depth = self.depth_layers[i](feat.flatten(end_dim=1).float()).exp()
+            # rescale by the true focal length relative to the reference one
+            depth = depth.transpose(0, -1) * focal / self.equal_focal
+            depth = depth.transpose(0, -1)
+            depths.append(depth)
+        if gt_depths is not None and self.training:
+            loss = self.loss(depths, gt_depths)
+            return loss
+        return depths
+
+    def loss(self, depth_preds, gt_depths):
+        """Mean absolute error over valid (gt > 0, non-NaN pred) pixels."""
+        loss = 0.0
+        for pred, gt in zip(depth_preds, gt_depths):
+            pred = pred.permute(0, 2, 3, 1).contiguous().reshape(-1)
+            gt = gt.reshape(-1)
+            fg_mask = torch.logical_and(
+                gt > 0.0, torch.logical_not(torch.isnan(pred))
+            )
+            gt = gt[fg_mask]
+            pred = pred[fg_mask]
+            pred = torch.clip(pred, 0.0, self.max_depth)
+            # accumulate in fp32 to avoid overflow under mixed precision
+            with autocast(enabled=False):
+                error = torch.abs(pred - gt).sum()
+                _loss = (
+                    error
+                    / max(1.0, len(gt) * len(depth_preds))
+                    * self.loss_weight
+                )
+            loss = loss + _loss
+        return loss
+
+
+@FEEDFORWARD_NETWORK.register_module()
+class AsymmetricFFN(BaseModule):
+ def __init__(
+ self,
+ in_channels=None,
+ pre_norm=None,
+ embed_dims=256,
+ feedforward_channels=1024,
+ num_fcs=2,
+ act_cfg=dict(type="ReLU", inplace=True),
+ ffn_drop=0.0,
+ dropout_layer=None,
+ add_identity=True,
+ init_cfg=None,
+ **kwargs,
+ ):
+ super(AsymmetricFFN, self).__init__(init_cfg)
+ assert num_fcs >= 2, (
+ "num_fcs should be no less " f"than 2. got {num_fcs}."
+ )
+ self.in_channels = in_channels
+ self.pre_norm = pre_norm
+ self.embed_dims = embed_dims
+ self.feedforward_channels = feedforward_channels
+ self.num_fcs = num_fcs
+ self.act_cfg = act_cfg
+ self.activate = build_activation_layer(act_cfg)
+
+ layers = []
+ if in_channels is None:
+ in_channels = embed_dims
+ if pre_norm is not None:
+ self.pre_norm = build_norm_layer(pre_norm, in_channels)[1]
+
+ for _ in range(num_fcs - 1):
+ layers.append(
+ Sequential(
+ Linear(in_channels, feedforward_channels),
+ self.activate,
+ nn.Dropout(ffn_drop),
+ )
+ )
+ in_channels = feedforward_channels
+ layers.append(Linear(feedforward_channels, embed_dims))
+ layers.append(nn.Dropout(ffn_drop))
+ self.layers = Sequential(*layers)
+ self.dropout_layer = (
+ build_dropout(dropout_layer)
+ if dropout_layer
+ else torch.nn.Identity()
+ )
+ self.add_identity = add_identity
+ if self.add_identity:
+ self.identity_fc = (
+ torch.nn.Identity()
+ if in_channels == embed_dims
+ else Linear(self.in_channels, embed_dims)
+ )
+
+ def forward(self, x, identity=None):
+ if self.pre_norm is not None:
+ x = self.pre_norm(x)
+ out = self.layers(x)
+ if not self.add_identity:
+ return self.dropout_layer(out)
+ if identity is None:
+ identity = x
+ identity = self.identity_fc(identity)
+ return identity + self.dropout_layer(out)
diff --git a/projects/mmdet3d_plugin/models/detection3d/__init__.py b/projects/mmdet3d_plugin/models/detection3d/__init__.py
new file mode 100644
index 0000000..b2654b6
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/detection3d/__init__.py
@@ -0,0 +1,9 @@
+from .decoder import SparseBox3DDecoder
+from .target import SparseBox3DTarget
+from .detection3d_blocks import (
+ SparseBox3DRefinementModule,
+ SparseBox3DKeyPointsGenerator,
+ SparseBox3DEncoder,
+)
+from .losses import SparseBox3DLoss
+from .detection3d_head import Sparse4DHead
diff --git a/projects/mmdet3d_plugin/models/detection3d/decoder.py b/projects/mmdet3d_plugin/models/detection3d/decoder.py
new file mode 100644
index 0000000..5d000be
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/detection3d/decoder.py
@@ -0,0 +1,115 @@
+from typing import Optional
+
+import torch
+
+from mmdet.core.bbox.builder import BBOX_CODERS
+
+from projects.mmdet3d_plugin.core.box3d import *
+
def decode_box(box):
    """Convert the internal box encoding to the output representation.

    The internal layout (per the box3d constants) stores log-dimensions and
    a sin/cos yaw pair; the output replaces them with linear dimensions and
    a single yaw angle, keeping center and velocity untouched.
    """
    center = box[..., [X, Y, Z]]
    dims = box[..., [W, L, H]].exp()
    yaw = torch.atan2(box[..., SIN_YAW], box[..., COS_YAW]).unsqueeze(-1)
    velocity = box[..., VX:]
    return torch.cat([center, dims, yaw, velocity], dim=-1)
+
+
@BBOX_CODERS.register_module()
class SparseBox3DDecoder(object):
    """Decode raw head outputs into per-sample top-k box dictionaries.

    Args:
        num_output (int): Number of boxes kept per sample (top-k).
        score_threshold (float | None): Optional score filter applied after
            top-k selection.
        sorted (bool): Whether topk returns scores in sorted order.
    """

    def __init__(
        self,
        num_output: int = 300,
        score_threshold: Optional[float] = None,
        sorted: bool = True,
    ):
        super(SparseBox3DDecoder, self).__init__()
        self.num_output = num_output
        self.score_threshold = score_threshold
        self.sorted = sorted

    def decode(
        self,
        cls_scores,
        box_preds,
        instance_id=None,
        quality=None,
        visibility=None,
        output_idx=-1,
    ):
        """Select and decode the top-k boxes of decoder stage ``output_idx``.

        Args:
            cls_scores: Per-stage classification logits.
            box_preds: Per-stage box regressions (internal encoding).
            instance_id: Optional tracked instance ids; when given, class
                scores are first collapsed to the per-anchor maximum.
            quality: Optional per-stage quality estimates; the centerness
                channel rescales the classification scores.
            visibility: Optional per-stage visibility logits.
            output_idx: Which decoder stage to decode.

        Returns:
            list[dict]: One dict per sample with ``boxes_3d`` / ``scores_3d``
            / ``labels_3d`` on CPU, plus optional ``cls_scores``,
            ``instance_ids`` and ``visibility_scores``.
        """
        squeeze_cls = instance_id is not None

        cls_scores = cls_scores[output_idx].sigmoid()

        if squeeze_cls:
            cls_scores, cls_ids = cls_scores.max(dim=-1)
            cls_scores = cls_scores.unsqueeze(dim=-1)

        box_preds = box_preds[output_idx]
        bs, num_pred, num_cls = cls_scores.shape
        cls_scores, indices = cls_scores.flatten(start_dim=1).topk(
            self.num_output, dim=1, sorted=self.sorted
        )
        if not squeeze_cls:
            cls_ids = indices % num_cls
        if self.score_threshold is not None:
            mask = cls_scores >= self.score_threshold

        if visibility is not None and visibility[output_idx] is None:
            visibility = None
        # BUGFIX: the original indexed ``quality[output_idx]`` without first
        # checking ``quality is not None``, so calling decode() with the
        # default quality=None raised a TypeError.  Mirror the visibility
        # guard above.
        if quality is not None and quality[output_idx] is None:
            quality = None
        if quality is not None:
            centerness = quality[output_idx][..., CNS]
            centerness = torch.gather(centerness, 1, indices // num_cls)
            cls_scores_origin = cls_scores.clone()
            cls_scores *= centerness.sigmoid()
            cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True)
            if not squeeze_cls:
                cls_ids = torch.gather(cls_ids, 1, idx)
            if self.score_threshold is not None:
                mask = torch.gather(mask, 1, idx)
            indices = torch.gather(indices, 1, idx)

        output = []
        for i in range(bs):
            category_ids = cls_ids[i]
            if squeeze_cls:
                category_ids = category_ids[indices[i]]
            scores = cls_scores[i]
            box = box_preds[i, indices[i] // num_cls]
            if self.score_threshold is not None:
                category_ids = category_ids[mask[i]]
                scores = scores[mask[i]]
                box = box[mask[i]]
            if quality is not None:
                scores_origin = cls_scores_origin[i]
                if self.score_threshold is not None:
                    scores_origin = scores_origin[mask[i]]

            box = decode_box(box)
            output.append(
                {
                    "boxes_3d": box.cpu(),
                    "scores_3d": scores.cpu(),
                    "labels_3d": category_ids.cpu(),
                }
            )
            if quality is not None:
                output[-1]["cls_scores"] = scores_origin.cpu()
            if instance_id is not None:
                ids = instance_id[i, indices[i]]
                if self.score_threshold is not None:
                    ids = ids[mask[i]]
                output[-1]["instance_ids"] = ids
            if visibility is not None:
                vis_i = visibility[output_idx][i, indices[i] // num_cls, 0].sigmoid()
                if self.score_threshold is not None:
                    vis_i = vis_i[mask[i]]
                output[-1]["visibility_scores"] = vis_i.cpu()
        return output
diff --git a/projects/mmdet3d_plugin/models/detection3d/detection3d_blocks.py b/projects/mmdet3d_plugin/models/detection3d/detection3d_blocks.py
new file mode 100644
index 0000000..2bcabe8
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/detection3d/detection3d_blocks.py
@@ -0,0 +1,316 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+from mmcv.cnn import Linear, Scale, bias_init_with_prob
+from mmcv.runner.base_module import Sequential, BaseModule
+from mmcv.cnn import xavier_init
+from mmcv.cnn.bricks.registry import (
+ PLUGIN_LAYERS,
+ POSITIONAL_ENCODING,
+)
+
+from projects.mmdet3d_plugin.core.box3d import *
+from ..blocks import linear_relu_ln
+
+__all__ = [
+ "SparseBox3DRefinementModule",
+ "SparseBox3DKeyPointsGenerator",
+ "SparseBox3DEncoder",
+]
+
+
@POSITIONAL_ENCODING.register_module()
class SparseBox3DEncoder(BaseModule):
    """Embed 3D anchor boxes (center / size / yaw / velocity) into features.

    Each box component gets its own small MLP branch; branch outputs are
    either summed ("add") or concatenated ("cat"), optionally followed by a
    final MLP.
    """

    def __init__(
        self,
        embed_dims,
        vel_dims=3,
        mode="add",
        output_fc=True,
        in_loops=1,
        out_loops=2,
    ):
        super().__init__()
        assert mode in ["add", "cat"]
        self.embed_dims = embed_dims
        self.vel_dims = vel_dims
        self.mode = mode

        # A scalar embed_dims applies the same width to all five branches.
        dims = (
            embed_dims
            if isinstance(embed_dims, (list, tuple))
            else [embed_dims] * 5
        )

        def make_branch(in_dims, out_dims):
            # Linear + ReLU + LayerNorm stacks from the shared helper.
            return nn.Sequential(
                *linear_relu_ln(out_dims, in_loops, out_loops, in_dims)
            )

        self.pos_fc = make_branch(3, dims[0])
        self.size_fc = make_branch(3, dims[1])
        self.yaw_fc = make_branch(2, dims[2])
        if vel_dims > 0:
            self.vel_fc = make_branch(self.vel_dims, dims[3])
        self.output_fc = make_branch(dims[-1], dims[-1]) if output_fc else None

    def forward(self, box_3d: torch.Tensor):
        """Encode boxes of shape (..., state_dim) into feature vectors."""
        feats = [
            self.pos_fc(box_3d[..., [X, Y, Z]]),
            self.size_fc(box_3d[..., [W, L, H]]),
            self.yaw_fc(box_3d[..., [SIN_YAW, COS_YAW]]),
        ]
        if self.vel_dims > 0:
            feats.append(self.vel_fc(box_3d[..., VX : VX + self.vel_dims]))
        if self.mode == "add":
            output = sum(feats)
        else:  # "cat"
            output = torch.cat(feats, dim=-1)
        if self.output_fc is not None:
            output = self.output_fc(output)
        return output
+
+
@PLUGIN_LAYERS.register_module()
class SparseBox3DRefinementModule(BaseModule):
    """Predict per-anchor box deltas plus optional classification, quality
    and visibility heads from instance features + anchor embeddings.

    The regression branch outputs are added onto the current anchor state
    (delta refinement) for the position/size (and optionally yaw) channels.
    """

    def __init__(
        self,
        embed_dims=256,
        output_dim=11,
        num_cls=10,
        normalize_yaw=False,
        refine_yaw=False,
        with_cls_branch=True,
        with_quality_estimation=False,
        with_visibility_estimation=False,
    ):
        super(SparseBox3DRefinementModule, self).__init__()
        self.embed_dims = embed_dims
        self.output_dim = output_dim
        self.num_cls = num_cls
        self.normalize_yaw = normalize_yaw
        self.refine_yaw = refine_yaw

        # Indices of the anchor-state channels refined as deltas.
        self.refine_state = [X, Y, Z, W, L, H]
        if self.refine_yaw:
            self.refine_state += [SIN_YAW, COS_YAW]

        # Regression trunk: MLP -> linear head -> learnable per-channel scale.
        self.layers = nn.Sequential(
            *linear_relu_ln(embed_dims, 2, 2),
            Linear(self.embed_dims, self.output_dim),
            Scale([1.0] * self.output_dim),
        )
        self.with_cls_branch = with_cls_branch
        if with_cls_branch:
            self.cls_layers = nn.Sequential(
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(self.embed_dims, self.num_cls),
            )
        self.with_quality_estimation = with_quality_estimation
        if with_quality_estimation:
            # Two channels — presumably centerness and yawness (see CNS/YNS
            # usage in SparseBox3DLoss); TODO confirm against the loss cfg.
            self.quality_layers = nn.Sequential(
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(self.embed_dims, 2),
            )
        self.with_visibility_estimation = with_visibility_estimation
        if with_visibility_estimation:
            self.visibility_layers = nn.Sequential(
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(self.embed_dims, 1),
            )

    def init_weight(self):
        """Custom init hook (called by the head's init_weights loop)."""
        if self.with_cls_branch:
            # Focal-loss style prior: start with low foreground probability.
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.cls_layers[-1].bias, bias_init)
        if self.with_visibility_estimation:
            # Neutral prior — model starts with no bias toward visible/occluded
            nn.init.constant_(self.visibility_layers[-1].bias, 0.0)

    def forward(
        self,
        instance_feature: torch.Tensor,
        anchor: torch.Tensor,
        anchor_embed: torch.Tensor,
        time_interval: torch.Tensor = 1.0,
        return_cls=True,
    ):
        """Refine ``anchor`` and optionally emit cls/quality/visibility.

        Returns:
            tuple: (refined_anchor, cls_logits | None, quality | None,
            visibility_logits | None).
        """
        feature = instance_feature + anchor_embed
        output = self.layers(feature)
        # Delta refinement: add predicted offsets onto the selected anchor
        # state channels (in-place index assignment on the fresh output).
        output[..., self.refine_state] = (
            output[..., self.refine_state] + anchor[..., self.refine_state]
        )
        if self.normalize_yaw:
            # Keep (sin, cos) on the unit circle.
            output[..., [SIN_YAW, COS_YAW]] = torch.nn.functional.normalize(
                output[..., [SIN_YAW, COS_YAW]], dim=-1
            )
        if self.output_dim > 8:
            if not isinstance(time_interval, torch.Tensor):
                time_interval = instance_feature.new_tensor(time_interval)
            # Transpose so the per-batch time_interval broadcasts along the
            # trailing (batch) axis, then transpose back.
            translation = torch.transpose(output[..., VX:], 0, -1)
            velocity = torch.transpose(translation / time_interval, 0, -1)
            output[..., VX:] = velocity + anchor[..., VX:]

        if return_cls:
            assert self.with_cls_branch, "Without classification layers !!!"
            cls = self.cls_layers(instance_feature)
        else:
            cls = None
        if return_cls and self.with_quality_estimation:
            quality = self.quality_layers(feature)
        else:
            quality = None
        if return_cls and self.with_visibility_estimation:
            # (bs, N, 1) — raw logit; 1 = sensor-visible, 0 = occluded
            visibility = self.visibility_layers(instance_feature)
        else:
            visibility = None
        return output, cls, quality, visibility
+
+
@PLUGIN_LAYERS.register_module()
class SparseBox3DKeyPointsGenerator(BaseModule):
    """Generate sampling key points inside each anchor box.

    Fixed-scale points (box-frame offsets as fractions of box size) are
    optionally augmented with learnable offsets predicted from the instance
    feature; points are rotated by the anchor yaw and shifted to the anchor
    center.
    """

    def __init__(
        self,
        embed_dims=256,
        num_learnable_pts=0,
        fix_scale=None,
        ):
        super(SparseBox3DKeyPointsGenerator, self).__init__()
        self.embed_dims = embed_dims
        self.num_learnable_pts = num_learnable_pts
        if fix_scale is None:
            fix_scale = ((0.0, 0.0, 0.0),)
        # Constant, non-trainable offsets stored as a parameter so they move
        # with the module across devices.
        self.fix_scale = nn.Parameter(
            torch.tensor(fix_scale), requires_grad=False
        )
        self.num_pts = len(self.fix_scale) + num_learnable_pts
        if num_learnable_pts > 0:
            self.learnable_fc = Linear(self.embed_dims, num_learnable_pts * 3)

    def init_weight(self):
        """Custom init hook for the learnable-offset projection."""
        if self.num_learnable_pts > 0:
            xavier_init(self.learnable_fc, distribution="uniform", bias=0.0)

    def forward(
        self,
        anchor,
        instance_feature=None,
        T_cur2temp_list=None,
        cur_timestamp=None,
        temp_timestamps=None,
    ):
        """Compute key points for ``anchor``; optionally back-project them
        into past frames.

        NOTE(review): return arity differs by mode — key_points alone when
        no temporal inputs are given, else (key_points,
        temp_key_points_list); callers must handle both.
        """
        bs, num_anchor = anchor.shape[:2]
        # Box dimensions are stored in log-space; exp() recovers meters.
        size = anchor[..., None, [W, L, H]].exp()
        key_points = self.fix_scale * size
        if self.num_learnable_pts > 0 and instance_feature is not None:
            # sigmoid() - 0.5 bounds learnable offsets to (-0.5, 0.5) of the
            # box size, i.e. inside the box.
            learnable_scale = (
                self.learnable_fc(instance_feature)
                .reshape(bs, num_anchor, self.num_learnable_pts, 3)
                .sigmoid()
                - 0.5
            )
            key_points = torch.cat(
                [key_points, learnable_scale * size], dim=-2
            )

        # Yaw-only rotation about the z axis built from the anchor's
        # sin/cos channels.
        rotation_mat = anchor.new_zeros([bs, num_anchor, 3, 3])

        rotation_mat[:, :, 0, 0] = anchor[:, :, COS_YAW]
        rotation_mat[:, :, 0, 1] = -anchor[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 0] = anchor[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 1] = anchor[:, :, COS_YAW]
        rotation_mat[:, :, 2, 2] = 1

        # Rotate box-frame offsets, then translate to the anchor center.
        key_points = torch.matmul(
            rotation_mat[:, :, None], key_points[..., None]
        ).squeeze(-1)
        key_points = key_points + anchor[..., None, [X, Y, Z]]

        if (
            cur_timestamp is None
            or temp_timestamps is None
            or T_cur2temp_list is None
            or len(temp_timestamps) == 0
        ):
            return key_points

        # Back-project points into each past frame: undo constant-velocity
        # motion, then apply the current->temporal rigid transform in
        # homogeneous coordinates.
        temp_key_points_list = []
        velocity = anchor[..., VX:]
        for i, t_time in enumerate(temp_timestamps):
            time_interval = cur_timestamp - t_time
            translation = (
                velocity
                * time_interval.to(dtype=velocity.dtype)[:, None, None]
            )
            temp_key_points = key_points - translation[:, :, None]
            T_cur2temp = T_cur2temp_list[i].to(dtype=key_points.dtype)
            temp_key_points = (
                T_cur2temp[:, None, None, :3]
                @ torch.cat(
                    [
                        temp_key_points,
                        torch.ones_like(temp_key_points[..., :1]),
                    ],
                    dim=-1,
                ).unsqueeze(-1)
            )
            temp_key_points = temp_key_points.squeeze(-1)
            temp_key_points_list.append(temp_key_points)
        return key_points, temp_key_points_list

    @staticmethod
    def anchor_projection(
        anchor,
        T_src2dst_list,
        src_timestamp=None,
        dst_timestamps=None,
        time_intervals=None,
    ):
        """Project full anchor states into other frames.

        Center is motion-compensated (constant velocity) and rigidly
        transformed; size is frame-invariant; yaw (sin/cos) and velocity are
        rotated only.
        """
        dst_anchors = []
        for i in range(len(T_src2dst_list)):
            vel = anchor[..., VX:]
            vel_dim = vel.shape[-1]
            T_src2dst = torch.unsqueeze(
                T_src2dst_list[i].to(dtype=anchor.dtype), dim=1
            )

            center = anchor[..., [X, Y, Z]]
            # Prefer explicit time_intervals; else derive from timestamps;
            # else skip motion compensation entirely.
            if time_intervals is not None:
                time_interval = time_intervals[i]
            elif src_timestamp is not None and dst_timestamps is not None:
                time_interval = (src_timestamp - dst_timestamps[i]).to(
                    dtype=vel.dtype
                )
            else:
                time_interval = None
            if time_interval is not None:
                # Transpose trick lets per-batch intervals broadcast.
                translation = vel.transpose(0, -1) * time_interval
                translation = translation.transpose(0, -1)
                center = center - translation

            center = (
                torch.matmul(
                    T_src2dst[..., :3, :3], center[..., None]
                ).squeeze(dim=-1)
                + T_src2dst[..., :3, 3]
            )
            size = anchor[..., [W, L, H]]
            # Rotate the (cos, sin) heading vector by the 2x2 rotation block,
            # then swap back to (sin, cos) channel order.
            yaw = torch.matmul(
                T_src2dst[..., :2, :2],
                anchor[..., [COS_YAW, SIN_YAW], None],
            ).squeeze(-1)
            yaw = yaw[..., [1,0]]
            vel = torch.matmul(
                T_src2dst[..., :vel_dim, :vel_dim], vel[..., None]
            ).squeeze(-1)
            dst_anchor = torch.cat([center, size, yaw, vel], dim=-1)
            dst_anchors.append(dst_anchor)
        return dst_anchors

    @staticmethod
    def distance(anchor):
        """Euclidean distance of each anchor's (x, y) center from the origin."""
        return torch.norm(anchor[..., :2], p=2, dim=-1)
diff --git a/projects/mmdet3d_plugin/models/detection3d/detection3d_head.py b/projects/mmdet3d_plugin/models/detection3d/detection3d_head.py
new file mode 100644
index 0000000..da39352
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/detection3d/detection3d_head.py
@@ -0,0 +1,790 @@
+from typing import List, Optional, Tuple, Union
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from mmcv.cnn.bricks.registry import (
+ ATTENTION,
+ PLUGIN_LAYERS,
+ POSITIONAL_ENCODING,
+ FEEDFORWARD_NETWORK,
+ NORM_LAYERS,
+)
+from mmcv.runner import BaseModule, force_fp32
+from mmcv.utils import build_from_cfg
+from mmdet.core.bbox.builder import BBOX_SAMPLERS
+from mmdet.core.bbox.builder import BBOX_CODERS
+from mmdet.models import HEADS, LOSSES
+from mmdet.core import reduce_mean
+
+from ..blocks import DeformableFeatureAggregation as DFG
+
+__all__ = ["Sparse4DHead"]
+
+
@HEADS.register_module()
class Sparse4DHead(BaseModule):
    """Sparse, query-based 3D detection head with temporal instance reuse.

    Runs a configurable sequence of attention / deformable-aggregation /
    FFN / refinement operations over instance queries, supports denoising
    (DN) training, and an optional "temporal warmup" block that lets cached
    temporal queries interact before the main decoder.
    """

    def __init__(
        self,
        instance_bank: dict,
        anchor_encoder: dict,
        graph_model: dict,
        norm_layer: dict,
        ffn: dict,
        deformable_model: dict,
        refine_layer: dict,
        num_decoder: int = 6,
        num_single_frame_decoder: int = -1,
        temp_graph_model: dict = None,
        loss_cls: dict = None,
        loss_reg: dict = None,
        loss_visibility: dict = None,
        decoder: dict = None,
        sampler: dict = None,
        gt_cls_key: str = "gt_labels_3d",
        gt_reg_key: str = "gt_bboxes_3d",
        gt_id_key: str = "instance_id",
        gt_visibility_key: str = "gt_visibility",
        with_instance_id: bool = True,
        task_prefix: str = 'det',
        reg_weights: List = None,
        operation_order: Optional[List[str]] = None,
        cls_threshold_to_reg: float = -1,
        dn_loss_weight: float = 5.0,
        decouple_attn: bool = True,
        temporal_warmup_order: Optional[List[str]] = None,
        warmup_refine_layer: dict = None,
        warmup_ffn: dict = None,
        warmup_supervise_all: bool = False,
        init_cfg: dict = None,
        **kwargs,
    ):
        """Build all sub-modules from their configs.

        Most arguments are mmcv-style cfg dicts dispatched to the matching
        registries; ``operation_order`` names the per-layer op sequence
        ("temp_gnn" / "gnn" / "norm" / "deformable" / "ffn" / "refine").
        """
        super(Sparse4DHead, self).__init__(init_cfg)
        self.num_decoder = num_decoder
        self.num_single_frame_decoder = num_single_frame_decoder
        self.gt_cls_key = gt_cls_key
        self.gt_reg_key = gt_reg_key
        self.gt_id_key = gt_id_key
        self.with_instance_id = with_instance_id
        self.task_prefix = task_prefix
        self.cls_threshold_to_reg = cls_threshold_to_reg
        self.dn_loss_weight = dn_loss_weight
        self.decouple_attn = decouple_attn

        if reg_weights is None:
            # Default: uniform weight over the 10 regression channels.
            self.reg_weights = [1.0] * 10
        else:
            self.reg_weights = reg_weights

        if operation_order is None:
            operation_order = [
                "temp_gnn",
                "gnn",
                "norm",
                "deformable",
                "norm",
                "ffn",
                "norm",
                "refine",
            ] * num_decoder
            # delete the 'gnn' and 'norm' layers in the first transformer blocks
            operation_order = operation_order[3:]
        self.operation_order = operation_order

        # =========== build modules ===========
        def build(cfg, registry):
            if cfg is None:
                return None
            return build_from_cfg(cfg, registry)

        self.gt_visibility_key = gt_visibility_key
        self.instance_bank = build(instance_bank, PLUGIN_LAYERS)
        self.anchor_encoder = build(anchor_encoder, POSITIONAL_ENCODING)
        self.sampler = build(sampler, BBOX_SAMPLERS)
        self.decoder = build(decoder, BBOX_CODERS)
        self.loss_cls = build(loss_cls, LOSSES)
        self.loss_reg = build(loss_reg, LOSSES)
        self.loss_visibility = build(loss_visibility, LOSSES) if loss_visibility else None
        # Maps each op name to (cfg, registry) for layer construction below.
        self.op_config_map = {
            "temp_gnn": [temp_graph_model, ATTENTION],
            "gnn": [graph_model, ATTENTION],
            "norm": [norm_layer, NORM_LAYERS],
            "ffn": [ffn, FEEDFORWARD_NETWORK],
            "deformable": [deformable_model, ATTENTION],
            "refine": [refine_layer, PLUGIN_LAYERS],
        }
        self.layers = nn.ModuleList(
            [
                build(*self.op_config_map.get(op, [None, None]))
                for op in self.operation_order
            ]
        )
        self.temporal_warmup_order = list(temporal_warmup_order) if temporal_warmup_order else []
        # For "refine" in warmup, use warmup_refine_layer if provided (should have
        # with_cls_branch=False, with_quality_estimation=False to avoid unused params).
        # Falls back to refine_layer if warmup_refine_layer is not specified.
        warmup_op_config_map = dict(self.op_config_map)
        if warmup_refine_layer is not None:
            warmup_op_config_map["refine"] = [warmup_refine_layer, PLUGIN_LAYERS]
        if warmup_ffn is not None:
            warmup_op_config_map["ffn"] = [warmup_ffn, FEEDFORWARD_NETWORK]
        self.warmup_layers = nn.ModuleList(
            [
                build(*warmup_op_config_map.get(op, [None, None]))
                for op in self.temporal_warmup_order
            ]
        )
        self.embed_dims = self.instance_bank.embed_dims
        if self.decouple_attn:
            # Decoupled attention: queries/keys carry concatenated positional
            # embeddings, so values are projected to twice the width.
            self.fc_before = nn.Linear(
                self.embed_dims, self.embed_dims * 2, bias=False
            )
            self.fc_after = nn.Linear(
                self.embed_dims * 2, self.embed_dims, bias=False
            )
        else:
            self.fc_before = nn.Identity()
            self.fc_after = nn.Identity()
        # Dedicated fc projections for warmup GNN — must NOT share with fc_before/fc_after
        # because gradient checkpointing would fire DDP hooks twice for shared params.
        has_warmup_gnn = any(op == "gnn" for op in self.temporal_warmup_order)
        if has_warmup_gnn and self.decouple_attn:
            self.warmup_fc_before = nn.Linear(
                self.embed_dims, self.embed_dims * 2, bias=False
            )
            self.warmup_fc_after = nn.Linear(
                self.embed_dims * 2, self.embed_dims, bias=False
            )
        else:
            self.warmup_fc_before = nn.Identity()
            self.warmup_fc_after = nn.Identity()
        self.warmup_supervise_all = warmup_supervise_all
+
+ def init_weights(self):
+ for i, op in enumerate(self.operation_order):
+ if self.layers[i] is None:
+ continue
+ elif op != "refine":
+ for p in self.layers[i].parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for i, op in enumerate(self.temporal_warmup_order):
+ if self.warmup_layers[i] is None:
+ continue
+ elif op != "refine":
+ for p in self.warmup_layers[i].parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ if isinstance(self.warmup_fc_before, nn.Linear):
+ nn.init.xavier_uniform_(self.warmup_fc_before.weight)
+ if isinstance(self.warmup_fc_after, nn.Linear):
+ nn.init.xavier_uniform_(self.warmup_fc_after.weight)
+ for m in self.modules():
+ if hasattr(m, "init_weight"):
+ m.init_weight()
+
+ def _gnn_with_layer(self, layer, feat, anchor_embed):
+ """Self-attention (GNN) using an explicit layer rather than self.layers[i].
+ Used by the temporal warmup block where queries attend only to each other.
+ Uses warmup_fc_before/after (not shared with main decoder fc projections)."""
+ if self.decouple_attn:
+ q = torch.cat([feat, anchor_embed], dim=-1)
+ v = self.warmup_fc_before(feat)
+ return self.warmup_fc_after(layer(q, q, v))
+ else:
+ v = self.warmup_fc_before(feat)
+ return self.warmup_fc_after(layer(feat, feat, v, query_pos=anchor_embed))
+
+ def graph_model(
+ self,
+ index,
+ query,
+ key=None,
+ value=None,
+ query_pos=None,
+ key_pos=None,
+ **kwargs,
+ ):
+ if self.decouple_attn:
+ query = torch.cat([query, query_pos], dim=-1)
+ if key is not None:
+ key = torch.cat([key, key_pos], dim=-1)
+ query_pos, key_pos = None, None
+ if value is not None:
+ value = self.fc_before(value)
+ return self.fc_after(
+ self.layers[index](
+ query,
+ key,
+ value,
+ query_pos=query_pos,
+ key_pos=key_pos,
+ **kwargs,
+ )
+ )
+
    def forward(
        self,
        feature_maps: Union[torch.Tensor, List],
        metas: dict,
    ):
        """Run the full decoding pipeline for one (multi-view) frame.

        Stages: fetch instances from the bank, optionally build denoising
        (DN) anchors, run the temporal warmup block, run the main decoder
        op sequence, then split learnable vs. noisy predictions and cache
        state for the next frame.

        Returns:
            dict: predictions/classification/quality/visibility lists (one
            entry per supervised stage) plus DN targets and cached features.
        """
        if isinstance(feature_maps, torch.Tensor):
            feature_maps = [feature_maps]
        batch_size = feature_maps[0].shape[0]

        # ========= get instance info ============
        # Drop stale DN metas when the batch size changed between iterations.
        if (
            self.sampler.dn_metas is not None
            and self.sampler.dn_metas["dn_anchor"].shape[0] != batch_size
        ):
            self.sampler.dn_metas = None
        (
            instance_feature,
            anchor,
            temp_instance_feature,
            temp_anchor,
            time_interval,
        ) = self.instance_bank.get(
            batch_size, metas, dn_metas=self.sampler.dn_metas
        )

        # ========= prepare for denosing training ============
        # 1. get dn metas: noisy-anchors and corresponding GT
        # 2. concat learnable instances and noisy instances
        # 3. get attention mask
        attn_mask = None
        dn_metas = None
        temp_dn_reg_target = None
        if self.training and hasattr(self.sampler, "get_dn_anchors"):
            if self.gt_id_key in metas["img_metas"][0]:
                gt_instance_id = [
                    torch.from_numpy(x[self.gt_id_key]).cuda()
                    for x in metas["img_metas"]
                ]
            else:
                gt_instance_id = None
            dn_metas = self.sampler.get_dn_anchors(
                metas[self.gt_cls_key],
                metas[self.gt_reg_key],
                gt_instance_id,
            )
        if dn_metas is not None:
            (
                dn_anchor,
                dn_reg_target,
                dn_cls_target,
                dn_attn_mask,
                valid_mask,
                dn_id_target,
            ) = dn_metas
            num_dn_anchor = dn_anchor.shape[1]
            # Zero-pad DN anchors up to the full anchor state width.
            if dn_anchor.shape[-1] != anchor.shape[-1]:
                remain_state_dims = anchor.shape[-1] - dn_anchor.shape[-1]
                dn_anchor = torch.cat(
                    [
                        dn_anchor,
                        dn_anchor.new_zeros(
                            batch_size, num_dn_anchor, remain_state_dims
                        ),
                    ],
                    dim=-1,
                )
            anchor = torch.cat([anchor, dn_anchor], dim=1)
            instance_feature = torch.cat(
                [
                    instance_feature,
                    instance_feature.new_zeros(
                        batch_size, num_dn_anchor, instance_feature.shape[-1]
                    ),
                ],
                dim=1,
            )
            num_instance = instance_feature.shape[1]
            num_free_instance = num_instance - num_dn_anchor
            # Block attention between free and noisy instances; within the
            # noisy block use the sampler-provided group mask.
            attn_mask = anchor.new_ones(
                (num_instance, num_instance), dtype=torch.bool
            )
            attn_mask[:num_free_instance, :num_free_instance] = False
            attn_mask[num_free_instance:, num_free_instance:] = dn_attn_mask

        anchor_embed = self.anchor_encoder(anchor)
        if temp_anchor is not None:
            temp_anchor_embed = self.anchor_encoder(temp_anchor)
        else:
            temp_anchor_embed = None

        # =========== temporal warmup (Block 0) ====================
        # Social self-attention among temporal queries before they are merged
        # with current-frame detections. No image features used here.
        # Always runs (even on first frame) so warmup params always receive
        # gradients — avoids the need for find_unused_parameters=True.
        if self.temporal_warmup_order:
            if temp_instance_feature is not None:
                # Temporal case: warm up the cached temporal features
                w_feat = temp_instance_feature
                w_anchor = temp_anchor
                w_anchor_embed = temp_anchor_embed
                is_temporal = True
            else:
                # First frame: warm up the first num_temp_instances current slots
                num_ti = self.instance_bank.num_temp_instances
                w_feat = instance_feature[:, :num_ti]
                w_anchor = anchor[:, :num_ti]
                w_anchor_embed = anchor_embed[:, :num_ti]
                is_temporal = False
            w_cls, w_qt, w_vis = None, None, None
            for i, op in enumerate(self.temporal_warmup_order):
                if self.warmup_layers[i] is None:
                    continue
                if op == "gnn":
                    w_feat = self._gnn_with_layer(
                        self.warmup_layers[i], w_feat, w_anchor_embed
                    )
                elif op in ("norm", "ffn"):
                    w_feat = self.warmup_layers[i](w_feat)
                elif op == "refine":
                    w_anchor, w_cls, w_qt, w_vis = self.warmup_layers[i](
                        w_feat,
                        w_anchor,
                        w_anchor_embed,
                        time_interval=time_interval,
                        return_cls=True,
                    )
                    # Re-embed the refined anchors for subsequent warmup ops.
                    w_anchor_embed = self.anchor_encoder(w_anchor)
            if is_temporal:
                temp_instance_feature = w_feat
                temp_anchor = w_anchor
                temp_anchor_embed = w_anchor_embed
            else:
                # Inject warmed first-frame features back so warmup params
                # connect to the loss via the main decoder.
                num_ti = self.instance_bank.num_temp_instances
                instance_feature = torch.cat(
                    [w_feat, instance_feature[:, num_ti:]], dim=1
                )
                anchor = torch.cat([w_anchor, anchor[:, num_ti:]], dim=1)
                anchor_embed = torch.cat(
                    [w_anchor_embed, anchor_embed[:, num_ti:]], dim=1
                )

        # =================== forward the layers ====================
        prediction = []
        classification = []
        quality = []
        visibility = []
        num_warmup_preds = 0
        # If warmup produced a refine prediction on temporal instances, prepend it
        # so it gets supervised like any other intermediate decoder stage.
        # Pads non-temporal slots (num_ti:num_anchor) with initial anchor positions
        # and near-zero cls logits so the sampler treats them as background.
        # NOTE: do NOT gate this on is_temporal — the cls branch must always
        # participate in the loss so DDP doesn't see unused parameters on
        # first-frame batches (which have is_temporal=False).
        if (
            self.temporal_warmup_order
            and w_cls is not None
            and dn_metas is None
        ):
            num_ti = self.instance_bank.num_temp_instances
            num_anchor = self.instance_bank.num_anchor
            warmup_pred = torch.cat(
                [w_anchor, anchor[:, num_ti:num_anchor]], dim=1
            )
            warmup_cls = torch.cat(
                [
                    w_cls,
                    w_cls.new_full(
                        [batch_size, num_anchor - num_ti, w_cls.shape[-1]], -10.0
                    ),
                ],
                dim=1,
            )
            warmup_qt = (
                torch.cat(
                    [
                        w_qt,
                        w_qt.new_zeros(batch_size, num_anchor - num_ti, w_qt.shape[-1]),
                    ],
                    dim=1,
                )
                if w_qt is not None
                else None
            )
            warmup_vis = (
                torch.cat(
                    [
                        w_vis,
                        w_vis.new_zeros(batch_size, num_anchor - num_ti, w_vis.shape[-1]),
                    ],
                    dim=1,
                )
                if w_vis is not None
                else None
            )
            prediction.append(warmup_pred)
            classification.append(warmup_cls)
            quality.append(warmup_qt)
            num_warmup_preds = 1
            visibility.append(warmup_vis)
        num_main_decoder_refines = 0
        for i, op in enumerate(self.operation_order):
            if self.layers[i] is None:
                continue
            elif op == "temp_gnn":
                instance_feature = self.graph_model(
                    i,
                    instance_feature,
                    temp_instance_feature,
                    temp_instance_feature,
                    query_pos=anchor_embed,
                    key_pos=temp_anchor_embed,
                    attn_mask=attn_mask
                    if temp_instance_feature is None
                    else None,
                )
            elif op == "gnn":
                instance_feature = self.graph_model(
                    i,
                    instance_feature,
                    value=instance_feature,
                    query_pos=anchor_embed,
                    attn_mask=attn_mask,
                )
            elif op == "norm" or op == "ffn":
                instance_feature = self.layers[i](instance_feature)
            elif op == "deformable":
                instance_feature = self.layers[i](
                    instance_feature,
                    anchor,
                    anchor_embed,
                    feature_maps,
                    metas,
                )
            elif op == "refine":
                anchor, cls, qt, vis = self.layers[i](
                    instance_feature,
                    anchor,
                    anchor_embed,
                    time_interval=time_interval,
                    return_cls=True,
                )
                prediction.append(anchor)
                classification.append(cls)
                quality.append(qt)
                visibility.append(vis)
                num_main_decoder_refines += 1
                # After the last single-frame stage, merge in the cached
                # temporal instances and (optionally) temporal DN state.
                if num_main_decoder_refines == self.num_single_frame_decoder:
                    instance_feature, anchor = self.instance_bank.update(
                        instance_feature, anchor, cls,
                        cached_feature_override=temp_instance_feature,
                        cached_anchor_override=temp_anchor,
                    )
                    if (
                        dn_metas is not None
                        and self.sampler.num_temp_dn_groups > 0
                        and dn_id_target is not None
                    ):
                        (
                            instance_feature,
                            anchor,
                            temp_dn_reg_target,
                            temp_dn_cls_target,
                            temp_valid_mask,
                            dn_id_target,
                        ) = self.sampler.update_dn(
                            instance_feature,
                            anchor,
                            dn_reg_target,
                            dn_cls_target,
                            valid_mask,
                            dn_id_target,
                            self.instance_bank.num_anchor,
                            self.instance_bank.mask,
                        )
                anchor_embed = self.anchor_encoder(anchor)
                if (
                    len(prediction) > self.num_single_frame_decoder
                    and temp_anchor_embed is not None
                ):
                    temp_anchor_embed = anchor_embed[
                        :, : self.instance_bank.num_temp_instances
                    ]
            else:
                raise NotImplementedError(f"{op} is not supported.")

        output = {}

        # split predictions of learnable instances and noisy instances
        if dn_metas is not None:
            dn_classification = [
                x[:, num_free_instance:] for x in classification
            ]
            classification = [x[:, :num_free_instance] for x in classification]
            dn_prediction = [x[:, num_free_instance:] for x in prediction]
            prediction = [x[:, :num_free_instance] for x in prediction]
            quality = [
                x[:, :num_free_instance] if x is not None else None
                for x in quality
            ]
            visibility = [
                x[:, :num_free_instance] if x is not None else None
                for x in visibility
            ]
            output.update(
                {
                    "dn_prediction": dn_prediction,
                    "dn_classification": dn_classification,
                    "dn_reg_target": dn_reg_target,
                    "dn_cls_target": dn_cls_target,
                    "dn_valid_mask": valid_mask,
                }
            )
            if temp_dn_reg_target is not None:
                output.update(
                    {
                        "temp_dn_reg_target": temp_dn_reg_target,
                        "temp_dn_cls_target": temp_dn_cls_target,
                        "temp_dn_valid_mask": temp_valid_mask,
                        "dn_id_target": dn_id_target,
                    }
                )
                dn_cls_target = temp_dn_cls_target
                valid_mask = temp_valid_mask
            dn_instance_feature = instance_feature[:, num_free_instance:]
            dn_anchor = anchor[:, num_free_instance:]
            instance_feature = instance_feature[:, :num_free_instance]
            anchor_embed = anchor_embed[:, :num_free_instance]
            anchor = anchor[:, :num_free_instance]
            cls = cls[:, :num_free_instance]

            # cache dn_metas for temporal denoising
            self.sampler.cache_dn(
                dn_instance_feature,
                dn_anchor,
                dn_cls_target,
                valid_mask,
                dn_id_target,
            )
        output.update(
            {
                "classification": classification,
                "prediction": prediction,
                "quality": quality,
                "visibility": visibility,
                "instance_feature": instance_feature,
                "anchor_embed": anchor_embed,
                "num_warmup_preds": num_warmup_preds,
            }
        )

        # cache current instances for temporal modeling
        self.instance_bank.cache(
            instance_feature, anchor, cls, metas, feature_maps
        )
        if self.with_instance_id:
            instance_id = self.instance_bank.get_instance_id(
                cls, anchor, self.decoder.score_threshold
            )
            output["instance_id"] = instance_id
        return output
+
    # NOTE(review): apply_to=("model_outs") is a plain string, not a tuple —
    # mmcv's force_fp32 happens to work via substring matching; consider
    # ("model_outs",) for correctness.
    @force_fp32(apply_to=("model_outs"))
    def loss(self, model_outs, data, feature_maps=None):
        """Compute per-stage classification/regression (and optional
        visibility) losses, plus denoising losses when present.

        Args:
            model_outs (dict): Output of :meth:`forward`.
            data (dict): Ground-truth dict keyed by the gt_*_key attributes.
            feature_maps: Unused here; kept for interface compatibility.

        Returns:
            dict: Loss terms keyed by ``{task_prefix}_loss_*_{stage}``.
        """
        # ===================== prediction losses ======================
        cls_scores = model_outs["classification"]
        reg_preds = model_outs["prediction"]
        quality = model_outs["quality"]
        vis_scores = model_outs.get("visibility", [None] * len(cls_scores))
        num_warmup_preds = model_outs.get("num_warmup_preds", 0)
        output = {}
        for decoder_idx, (cls, reg, qt, vis) in enumerate(
            zip(cls_scores, reg_preds, quality, vis_scores)
        ):
            reg = reg[..., : len(self.reg_weights)]
            # For the single-frame stages right after warmup, optionally
            # restrict supervision to visible (v > 0) ground-truth boxes.
            if (
                self.warmup_supervise_all
                and num_warmup_preds <= decoder_idx < num_warmup_preds + self.num_single_frame_decoder
                and self.gt_visibility_key in data
            ):
                gt_vis = data[self.gt_visibility_key]
                gt_cls = [l[v > 0] for l, v in zip(data[self.gt_cls_key], gt_vis)]
                gt_reg = [b[v > 0] for b, v in zip(data[self.gt_reg_key], gt_vis)]
            else:
                gt_cls = data[self.gt_cls_key]
                gt_reg = data[self.gt_reg_key]
            cls_target, reg_target, reg_weights = self.sampler.sample(
                cls,
                reg,
                gt_cls,
                gt_reg,
            )
            reg_target = reg_target[..., : len(self.reg_weights)]
            reg_target_full = reg_target.clone()
            # All-zero regression targets mark unmatched (background) slots.
            mask = torch.logical_not(torch.all(reg_target == 0, dim=-1))
            mask_valid = mask.clone()

            # DDP-averaged positive count used as the loss normalizer.
            num_pos = max(
                reduce_mean(torch.sum(mask).to(dtype=reg.dtype)), 1.0
            )
            if self.cls_threshold_to_reg > 0:
                threshold = self.cls_threshold_to_reg
                mask = torch.logical_and(
                    mask, cls.max(dim=-1).values.sigmoid() > threshold
                )

            cls = cls.flatten(end_dim=1)
            cls_target = cls_target.flatten(end_dim=1)
            cls_loss = self.loss_cls(cls, cls_target, avg_factor=num_pos)

            mask = mask.reshape(-1)
            reg_weights = reg_weights * reg.new_tensor(self.reg_weights)
            reg_target = reg_target.flatten(end_dim=1)[mask]
            reg = reg.flatten(end_dim=1)[mask]
            reg_weights = reg_weights.flatten(end_dim=1)[mask]
            # Sanitize NaN targets (e.g. missing velocity annotations).
            reg_target = torch.where(
                reg_target.isnan(), reg.new_tensor(0.0), reg_target
            )
            cls_target = cls_target[mask]
            if qt is not None:
                qt = qt.flatten(end_dim=1)[mask]

            reg_loss = self.loss_reg(
                reg,
                reg_target,
                weight=reg_weights,
                avg_factor=num_pos,
                prefix=f"{self.task_prefix}_",
                suffix=f"_{decoder_idx}",
                quality=qt,
                cls_target=cls_target,
            )

            output[f"{self.task_prefix}_loss_cls_{decoder_idx}"] = cls_loss
            output.update(reg_loss)

            # ---- visibility loss (only on matched / positive anchors) ----
            if (
                vis is not None
                and self.loss_visibility is not None
                and self.gt_visibility_key in data
            ):
                gt_vis_list = data[self.gt_visibility_key]
                bs_v, num_pred_v = vis.shape[:2]
                vis_target = vis.new_zeros(bs_v, num_pred_v)
                # Scatter GT visibility onto the matched prediction slots
                # using the sampler's last matching indices.
                for b_i, (pred_idx, target_idx) in enumerate(
                    self.sampler.indices
                ):
                    if (
                        pred_idx is not None
                        and len(pred_idx) > 0
                        and len(gt_vis_list[b_i]) > 0
                    ):
                        vis_target[b_i, pred_idx] = (
                            gt_vis_list[b_i]
                            .to(vis.device)
                            .float()[target_idx]
                        )
                matched = mask_valid.reshape(-1)
                vis_loss = self.loss_visibility(
                    vis.squeeze(-1).flatten(end_dim=1)[matched].unsqueeze(-1),
                    vis_target.flatten(end_dim=1)[matched].long(),
                    avg_factor=num_pos,
                )
                output[
                    f"{self.task_prefix}_loss_visibility_{decoder_idx}"
                ] = vis_loss

        if "dn_prediction" not in model_outs:
            return output

        # ===================== denoising losses ======================
        dn_cls_scores = model_outs["dn_classification"]
        dn_reg_preds = model_outs["dn_prediction"]

        (
            dn_valid_mask,
            dn_cls_target,
            dn_reg_target,
            dn_pos_mask,
            reg_weights,
            num_dn_pos,
        ) = self.prepare_for_dn_loss(model_outs)
        for decoder_idx, (cls, reg) in enumerate(
            zip(dn_cls_scores, dn_reg_preds)
        ):
            # Switch to the temporal DN targets from the stage where the
            # instance bank merged temporal instances.
            if (
                "temp_dn_valid_mask" in model_outs
                and decoder_idx == self.num_single_frame_decoder
            ):
                (
                    dn_valid_mask,
                    dn_cls_target,
                    dn_reg_target,
                    dn_pos_mask,
                    reg_weights,
                    num_dn_pos,
                ) = self.prepare_for_dn_loss(model_outs, prefix="temp_")

            cls_loss = self.loss_cls(
                cls.flatten(end_dim=1)[dn_valid_mask],
                dn_cls_target,
                avg_factor=num_dn_pos,
            )
            reg_loss = self.loss_reg(
                reg.flatten(end_dim=1)[dn_valid_mask][dn_pos_mask][
                    ..., : len(self.reg_weights)
                ],
                dn_reg_target,
                avg_factor=num_dn_pos,
                weight=reg_weights,
                prefix=f"{self.task_prefix}_",
                suffix=f"_dn_{decoder_idx}",
            )
            output[f"{self.task_prefix}_loss_cls_dn_{decoder_idx}"] = cls_loss
            output.update(reg_loss)
        return output
+
+ def prepare_for_dn_loss(self, model_outs, prefix=""):
+ dn_valid_mask = model_outs[f"{prefix}dn_valid_mask"].flatten(end_dim=1)
+ dn_cls_target = model_outs[f"{prefix}dn_cls_target"].flatten(
+ end_dim=1
+ )[dn_valid_mask]
+ dn_reg_target = model_outs[f"{prefix}dn_reg_target"].flatten(
+ end_dim=1
+ )[dn_valid_mask][..., : len(self.reg_weights)]
+ dn_pos_mask = dn_cls_target >= 0
+ dn_reg_target = dn_reg_target[dn_pos_mask]
+ reg_weights = dn_reg_target.new_tensor(self.reg_weights)[None].tile(
+ dn_reg_target.shape[0], 1
+ )
+ num_dn_pos = max(
+ reduce_mean(torch.sum(dn_valid_mask).to(dtype=reg_weights.dtype)),
+ 1.0,
+ )
+ return (
+ dn_valid_mask,
+ dn_cls_target,
+ dn_reg_target,
+ dn_pos_mask,
+ reg_weights,
+ num_dn_pos,
+ )
+
+ @force_fp32(apply_to=("model_outs"))
+ def post_process(self, model_outs, output_idx=-1):
+ vis_list = model_outs.get("visibility")
+ # Only pass visibility to decode() when the head is active (not all-None).
+ # This keeps backward-compatibility with decoders that lack the param.
+ vis_kwarg = {}
+ if vis_list is not None and any(v is not None for v in vis_list):
+ vis_kwarg = {"visibility": vis_list}
+ return self.decoder.decode(
+ model_outs["classification"],
+ model_outs["prediction"],
+ instance_id=model_outs.get("instance_id"),
+ quality=model_outs.get("quality"),
+ output_idx=output_idx,
+ **vis_kwarg,
+ )
diff --git a/projects/mmdet3d_plugin/models/detection3d/losses.py b/projects/mmdet3d_plugin/models/detection3d/losses.py
new file mode 100644
index 0000000..f4d6656
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/detection3d/losses.py
@@ -0,0 +1,93 @@
+import torch
+import torch.nn as nn
+
+from mmcv.utils import build_from_cfg
+from mmdet.models.builder import LOSSES
+
+from projects.mmdet3d_plugin.core.box3d import *
+
+
@LOSSES.register_module()
class SparseBox3DLoss(nn.Module):
    """Composite 3D box loss: box regression plus optional centerness
    ("cns") and yaw-direction ("yns") quality losses.

    Args:
        loss_box: Config of the box regression loss (required).
        loss_centerness: Optional config of the centerness loss.
        loss_yawness: Optional config of the yaw-direction loss.
        cls_allow_reverse: Optional list of class ids whose heading is
            direction-agnostic (e.g. barrier in nuScenes); for these, the
            target yaw may be flipped to match the prediction's hemisphere.
    """

    def __init__(
        self,
        loss_box,
        loss_centerness=None,
        loss_yawness=None,
        cls_allow_reverse=None,
    ):
        super().__init__()

        def build(cfg, registry):
            # Sub-losses are optional; skip building when not configured.
            if cfg is None:
                return None
            return build_from_cfg(cfg, registry)

        self.loss_box = build(loss_box, LOSSES)
        self.loss_cns = build(loss_centerness, LOSSES)
        self.loss_yns = build(loss_yawness, LOSSES)
        self.cls_allow_reverse = cls_allow_reverse

    def forward(
        self,
        box,
        box_target,
        weight=None,
        avg_factor=None,
        prefix="",
        suffix="",
        quality=None,
        cls_target=None,
        **kwargs,
    ):
        """Compute the box loss (and quality losses when available).

        Returns:
            Dict of named scalar losses, keyed as
            ``f"{prefix}loss_box{suffix}"`` (and ``loss_cns``/``loss_yns``
            when ``quality`` is provided).
        """
        # Some categories do not distinguish between positive and negative
        # directions. For example, barrier in nuScenes dataset.
        if self.cls_allow_reverse is not None and cls_target is not None:
            if_reverse = (
                torch.nn.functional.cosine_similarity(
                    box_target[..., [SIN_YAW, COS_YAW]],
                    box[..., [SIN_YAW, COS_YAW]],
                    dim=-1,
                )
                < 0
            )
            if_reverse = (
                torch.isin(
                    cls_target, cls_target.new_tensor(self.cls_allow_reverse)
                )
                & if_reverse
            )
            # Bug fix: clone before flipping so the *caller's* target tensor
            # is not mutated in place (the original wrote through box_target,
            # leaking the yaw flip to any downstream user of the tensor).
            box_target = box_target.clone()
            box_target[..., [SIN_YAW, COS_YAW]] = torch.where(
                if_reverse[..., None],
                -box_target[..., [SIN_YAW, COS_YAW]],
                box_target[..., [SIN_YAW, COS_YAW]],
            )

        output = {}
        box_loss = self.loss_box(
            box, box_target, weight=weight, avg_factor=avg_factor
        )
        output[f"{prefix}loss_box{suffix}"] = box_loss

        if quality is not None:
            cns = quality[..., CNS]
            yns = quality[..., YNS].sigmoid()
            # Centerness target: exp(-||center error||), in (0, 1].
            cns_target = torch.norm(
                box_target[..., [X, Y, Z]] - box[..., [X, Y, Z]], p=2, dim=-1
            )
            cns_target = torch.exp(-cns_target)
            cns_loss = self.loss_cns(cns, cns_target, avg_factor=avg_factor)
            output[f"{prefix}loss_cns{suffix}"] = cns_loss

            # Yawness target: 1 when predicted heading points the same way
            # as the (possibly flipped) target heading, else 0.
            yns_target = (
                torch.nn.functional.cosine_similarity(
                    box_target[..., [SIN_YAW, COS_YAW]],
                    box[..., [SIN_YAW, COS_YAW]],
                    dim=-1,
                )
                > 0
            )
            yns_target = yns_target.float()
            yns_loss = self.loss_yns(yns, yns_target, avg_factor=avg_factor)
            output[f"{prefix}loss_yns{suffix}"] = yns_loss
        return output
diff --git a/projects/mmdet3d_plugin/models/detection3d/target.py b/projects/mmdet3d_plugin/models/detection3d/target.py
new file mode 100644
index 0000000..fa284e7
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/detection3d/target.py
@@ -0,0 +1,437 @@
+import torch
+import numpy as np
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+
+from mmdet.core.bbox.builder import BBOX_SAMPLERS
+
+from projects.mmdet3d_plugin.core.box3d import *
+from ..base_target import BaseTargetWithDenoising
+
+
+__all__ = ["SparseBox3DTarget"]
+
+
@BBOX_SAMPLERS.register_module()
class SparseBox3DTarget(BaseTargetWithDenoising):
    """Hungarian-matching target assigner for sparse 3D box detection.

    Builds a focal-style classification cost plus a weighted L1 box cost,
    solves the one-to-one assignment with
    ``scipy.optimize.linear_sum_assignment``, and additionally generates
    (optionally temporal) denoising (DN) groups of noised GT anchors.
    """

    def __init__(
        self,
        cls_weight=2.0,
        alpha=0.25,
        gamma=2,
        eps=1e-12,
        box_weight=0.25,
        reg_weights=None,
        cls_wise_reg_weights=None,
        num_dn_groups=0,
        dn_noise_scale=0.5,
        max_dn_gt=32,
        add_neg_dn=True,
        num_temp_dn_groups=0,
    ):
        """
        Args:
            cls_weight: Scale of the focal classification matching cost.
            alpha: Focal alpha for the classification cost.
            gamma: Focal gamma for the classification cost.
            eps: Numerical floor inside log() for the focal cost.
            box_weight: Scale of the L1 box matching cost.
            reg_weights: Per-dimension weights over the encoded box target;
                defaults to [1]*8 + [0]*2 (velocity dims down-weighted to 0).
            cls_wise_reg_weights: Optional {class_id: weights} override
                applied per instance of that class.
            num_dn_groups: Number of denoising groups (<= 0 disables DN).
            dn_noise_scale: Scale of the uniform noise added to GT anchors.
            max_dn_gt: Max GT boxes kept per sample for DN (0 = unlimited).
            add_neg_dn: Also generate negative (larger-noise) DN anchors.
            num_temp_dn_groups: Number of DN groups propagated across frames.
        """
        super(SparseBox3DTarget, self).__init__(
            num_dn_groups, num_temp_dn_groups
        )
        self.cls_weight = cls_weight
        self.box_weight = box_weight
        self.alpha = alpha
        self.gamma = gamma
        self.eps = eps
        self.reg_weights = reg_weights
        if self.reg_weights is None:
            self.reg_weights = [1.0] * 8 + [0.0] * 2
        self.cls_wise_reg_weights = cls_wise_reg_weights
        self.dn_noise_scale = dn_noise_scale
        self.max_dn_gt = max_dn_gt
        self.add_neg_dn = add_neg_dn

    def encode_reg_target(self, box_target, device=None):
        """Encode decoded GT boxes into the regression-target format.

        Per box: [x, y, z, w, l, h, yaw, extra...] ->
        [x, y, z, log w, log l, log h, sin yaw, cos yaw, extra...].

        Args:
            box_target: List (per batch item) of (N_i, >=7) tensors.
            device: Optional device to move the encoded tensors to.

        Returns:
            List of encoded tensors, one per batch item.
        """
        outputs = []
        for box in box_target:
            output = torch.cat(
                [
                    box[..., [X, Y, Z]],
                    box[..., [W, L, H]].log(),
                    torch.sin(box[..., YAW]).unsqueeze(-1),
                    torch.cos(box[..., YAW]).unsqueeze(-1),
                    box[..., YAW + 1 :],
                ],
                dim=-1,
            )
            if device is not None:
                output = output.to(device=device)
            outputs.append(output)
        return outputs

    def sample(
        self,
        cls_pred,
        box_pred,
        cls_target,
        box_target,
    ):
        """One-to-one match predictions to GT via Hungarian assignment.

        Args:
            cls_pred: (bs, num_pred, num_cls) classification logits.
            box_pred: (bs, num_pred, state_dims) box predictions.
            cls_target: List (per batch item) of (N_i,) GT labels.
            box_target: List (per batch item) of GT boxes (decoded format).

        Returns:
            Tuple of (cls_target, box_target, reg_weights), each shaped
            (bs, num_pred, ...); unmatched predictions get class ``num_cls``
            (background) and all-zero box targets/weights.
        """
        bs, num_pred, num_cls = cls_pred.shape

        cls_cost = self._cls_cost(cls_pred, cls_target)

        box_target = self.encode_reg_target(box_target, box_pred.device)

        instance_reg_weights = []
        for i in range(len(box_target)):
            # NaN dims in the GT (e.g. unknown velocity) get zero weight.
            weights = torch.logical_not(box_target[i].isnan()).to(
                dtype=box_target[i].dtype
            )
            if self.cls_wise_reg_weights is not None:
                for cls, weight in self.cls_wise_reg_weights.items():
                    weights = torch.where(
                        (cls_target[i] == cls)[:, None],
                        weights.new_tensor(weight),
                        weights,
                    )
            instance_reg_weights.append(weights)
        box_cost = self._box_cost(box_pred, box_target, instance_reg_weights)

        indices = []
        for i in range(bs):
            if cls_cost[i] is not None and box_cost[i] is not None:
                cost = (cls_cost[i] + box_cost[i]).detach().cpu().numpy()
                # Sanitize -inf/NaN so linear_sum_assignment stays feasible.
                cost = np.where(np.isneginf(cost) | np.isnan(cost), 1e8, cost)
                assign = linear_sum_assignment(cost)
                indices.append(
                    [cls_pred.new_tensor(x, dtype=torch.int64) for x in assign]
                )
            else:
                # No GT in this sample: nothing to match.
                indices.append([None, None])

        # Default class is num_cls (background) for unmatched predictions.
        output_cls_target = (
            cls_target[0].new_ones([bs, num_pred], dtype=torch.long) * num_cls
        )
        output_box_target = box_pred.new_zeros(box_pred.shape)
        output_reg_weights = box_pred.new_zeros(box_pred.shape)
        for i, (pred_idx, target_idx) in enumerate(indices):
            if len(cls_target[i]) == 0:
                continue
            output_cls_target[i, pred_idx] = cls_target[i][target_idx]
            output_box_target[i, pred_idx] = box_target[i][target_idx]
            output_reg_weights[i, pred_idx] = instance_reg_weights[i][
                target_idx
            ]
        # Cache the match for downstream consumers.
        # NOTE(review): presumably read by the motion/planning target
        # assigner — confirm against callers before changing.
        self.indices = indices
        return output_cls_target, output_box_target, output_reg_weights

    def _cls_cost(self, cls_pred, cls_target):
        """Focal-style classification matching cost.

        Returns a list (per batch item) of (num_pred, N_i) cost tensors, or
        None for items with no GT.
        """
        bs = cls_pred.shape[0]
        cls_pred = cls_pred.sigmoid()
        cost = []
        for i in range(bs):
            if len(cls_target[i]) > 0:
                neg_cost = (
                    -(1 - cls_pred[i] + self.eps).log()
                    * (1 - self.alpha)
                    * cls_pred[i].pow(self.gamma)
                )
                pos_cost = (
                    -(cls_pred[i] + self.eps).log()
                    * self.alpha
                    * (1 - cls_pred[i]).pow(self.gamma)
                )
                cost.append(
                    (pos_cost[:, cls_target[i]] - neg_cost[:, cls_target[i]])
                    * self.cls_weight
                )
            else:
                cost.append(None)
        return cost

    def _box_cost(self, box_pred, box_target, instance_reg_weights):
        """Weighted L1 box matching cost.

        Returns a list (per batch item) of (num_pred, N_i) cost tensors, or
        None for items with no GT.
        """
        bs = box_pred.shape[0]
        cost = []
        for i in range(bs):
            if len(box_target[i]) > 0:
                cost.append(
                    torch.sum(
                        torch.abs(box_pred[i, :, None] - box_target[i][None])
                        * instance_reg_weights[i][None]
                        * box_pred.new_tensor(self.reg_weights),
                        dim=-1,
                    )
                    * self.box_weight
                )
            else:
                cost.append(None)
        return cost

    def get_dn_anchors(self, cls_target, box_target, gt_instance_id=None):
        """Generate noised GT anchors (denoising groups) and their targets.

        Returns None when DN is disabled or there is no GT at all;
        otherwise a tuple of (dn_anchor, dn_box_target, dn_cls_target,
        attn_mask, valid_mask, dn_id_target).
        """
        if self.num_dn_groups <= 0:
            return None
        if self.num_temp_dn_groups <= 0:
            # Instance ids are only needed for temporal DN propagation.
            gt_instance_id = None

        if self.max_dn_gt > 0:
            # Cap GT count per sample to bound the DN query budget.
            cls_target = [x[: self.max_dn_gt] for x in cls_target]
            box_target = [x[: self.max_dn_gt] for x in box_target]
            if gt_instance_id is not None:
                gt_instance_id = [x[: self.max_dn_gt] for x in gt_instance_id]

        max_dn_gt = max([len(x) for x in cls_target])
        if max_dn_gt == 0:
            return None
        # Pad per-sample GT to the batch max; padding labels are -1.
        cls_target = torch.stack(
            [
                F.pad(x, (0, max_dn_gt - x.shape[0]), value=-1)
                for x in cls_target
            ]
        )
        box_target = self.encode_reg_target(box_target, cls_target.device)
        box_target = torch.stack(
            [F.pad(x, (0, 0, 0, max_dn_gt - x.shape[0])) for x in box_target]
        )
        # Zero out padded boxes so noise is added to a clean baseline.
        box_target = torch.where(
            cls_target[..., None] == -1, box_target.new_tensor(0), box_target
        )
        if gt_instance_id is not None:
            gt_instance_id = torch.stack(
                [
                    F.pad(x, (0, max_dn_gt - x.shape[0]), value=-1)
                    for x in gt_instance_id
                ]
            )

        bs, num_gt, state_dims = box_target.shape
        if self.num_dn_groups > 1:
            # Replicate GT once per DN group along the batch dim.
            cls_target = cls_target.tile(self.num_dn_groups, 1)
            box_target = box_target.tile(self.num_dn_groups, 1, 1)
            if gt_instance_id is not None:
                gt_instance_id = gt_instance_id.tile(self.num_dn_groups, 1)

        # Positive DN anchors: uniform noise in [-scale, scale] around GT.
        noise = torch.rand_like(box_target) * 2 - 1
        noise *= box_target.new_tensor(self.dn_noise_scale)
        dn_anchor = box_target + noise
        if self.add_neg_dn:
            # Negative DN anchors: larger noise with magnitude in
            # [scale, 2*scale] and a random sign, pushed away from GT.
            noise_neg = torch.rand_like(box_target) + 1
            flag = torch.where(
                torch.rand_like(box_target) > 0.5,
                noise_neg.new_tensor(1),
                noise_neg.new_tensor(-1),
            )
            noise_neg *= flag
            noise_neg *= box_target.new_tensor(self.dn_noise_scale)
            dn_anchor = torch.cat([dn_anchor, box_target + noise_neg], dim=1)
            num_gt *= 2

        box_cost = self._box_cost(
            dn_anchor, box_target, torch.ones_like(box_target)
        )
        dn_box_target = torch.zeros_like(dn_anchor)
        # -3 marks unassigned DN slots, distinguishable from padding (-1).
        dn_cls_target = -torch.ones_like(cls_target) * 3
        if gt_instance_id is not None:
            dn_id_target = -torch.ones_like(gt_instance_id)
        if self.add_neg_dn:
            dn_cls_target = torch.cat([dn_cls_target, dn_cls_target], dim=1)
            if gt_instance_id is not None:
                dn_id_target = torch.cat([dn_id_target, dn_id_target], dim=1)

        # Re-match noised anchors to GT so each anchor gets the closest box.
        for i in range(dn_anchor.shape[0]):
            cost = box_cost[i].cpu().numpy()
            anchor_idx, gt_idx = linear_sum_assignment(cost)
            anchor_idx = dn_anchor.new_tensor(anchor_idx, dtype=torch.int64)
            gt_idx = dn_anchor.new_tensor(gt_idx, dtype=torch.int64)
            dn_box_target[i, anchor_idx] = box_target[i, gt_idx]
            dn_cls_target[i, anchor_idx] = cls_target[i, gt_idx]
            if gt_instance_id is not None:
                dn_id_target[i, anchor_idx] = gt_instance_id[i, gt_idx]
        # Fold the group replication back from the batch dim:
        # (groups*bs, num_gt, ...) -> (bs, groups*num_gt, ...).
        dn_anchor = (
            dn_anchor.reshape(self.num_dn_groups, bs, num_gt, state_dims)
            .permute(1, 0, 2, 3)
            .flatten(1, 2)
        )
        dn_box_target = (
            dn_box_target.reshape(self.num_dn_groups, bs, num_gt, state_dims)
            .permute(1, 0, 2, 3)
            .flatten(1, 2)
        )
        dn_cls_target = (
            dn_cls_target.reshape(self.num_dn_groups, bs, num_gt)
            .permute(1, 0, 2)
            .flatten(1)
        )
        if gt_instance_id is not None:
            dn_id_target = (
                dn_id_target.reshape(self.num_dn_groups, bs, num_gt)
                .permute(1, 0, 2)
                .flatten(1)
            )
        else:
            dn_id_target = None
        valid_mask = dn_cls_target >= 0
        if self.add_neg_dn:
            cls_target = (
                torch.cat([cls_target, cls_target], dim=1)
                .reshape(self.num_dn_groups, bs, num_gt)
                .permute(1, 0, 2)
                .flatten(1)
            )
            # Negative (unassigned, == -3) slots of real GT also count as
            # valid; "valid" denotes items that do not come from padding.
            valid_mask = torch.logical_or(
                valid_mask, ((cls_target >= 0) & (dn_cls_target == -3))
            )
        # Block attention across DN groups (and between groups and each
        # other); each group may only attend within itself.
        attn_mask = dn_box_target.new_ones(
            num_gt * self.num_dn_groups, num_gt * self.num_dn_groups
        )
        for i in range(self.num_dn_groups):
            start = num_gt * i
            end = start + num_gt
            attn_mask[start:end, start:end] = 0
        attn_mask = attn_mask == 1
        dn_cls_target = dn_cls_target.long()
        return (
            dn_anchor,
            dn_box_target,
            dn_cls_target,
            attn_mask,
            valid_mask,
            dn_id_target,
        )

    def update_dn(
        self,
        instance_feature,
        anchor,
        dn_reg_target,
        dn_cls_target,
        valid_mask,
        dn_id_target,
        num_noraml_anchor,  # [sic] number of non-DN anchors at the front
        temporal_valid_mask,
    ):
        """Merge cached temporal DN groups into the current frame's DN set.

        Matches cached DN queries to current ones by instance id, pads or
        truncates to the current per-group DN length, and concatenates the
        temporal groups in front of the remaining fresh groups.
        """
        bs, num_anchor = instance_feature.shape[:2]
        if temporal_valid_mask is None:
            # No temporal context at all: drop any cached DN state.
            self.dn_metas = None
        if self.dn_metas is None or num_noraml_anchor >= num_anchor:
            return (
                instance_feature,
                anchor,
                dn_reg_target,
                dn_cls_target,
                valid_mask,
                dn_id_target,
            )

        # split instance_feature and anchor into non-dn and dn
        num_dn = num_anchor - num_noraml_anchor
        dn_instance_feature = instance_feature[:, -num_dn:]
        dn_anchor = anchor[:, -num_dn:]
        instance_feature = instance_feature[:, :num_noraml_anchor]
        anchor = anchor[:, :num_noraml_anchor]

        # reshape all dn metas from (bs,num_all_dn,xxx)
        # to (bs, dn_group, num_dn_per_group, xxx)
        num_dn_groups = self.num_dn_groups
        num_dn = num_dn // num_dn_groups
        dn_feat = dn_instance_feature.reshape(bs, num_dn_groups, num_dn, -1)
        dn_anchor = dn_anchor.reshape(bs, num_dn_groups, num_dn, -1)
        dn_reg_target = dn_reg_target.reshape(bs, num_dn_groups, num_dn, -1)
        dn_cls_target = dn_cls_target.reshape(bs, num_dn_groups, num_dn)
        valid_mask = valid_mask.reshape(bs, num_dn_groups, num_dn)
        # NOTE(review): dn_id is only bound when dn_id_target is not None,
        # but it is used unconditionally below — presumably temporal DN is
        # always run with instance ids; confirm against callers.
        if dn_id_target is not None:
            dn_id = dn_id_target.reshape(bs, num_dn_groups, num_dn)

        # update temp_dn_metas by instance_id
        temp_dn_feat = self.dn_metas["dn_instance_feature"]
        _, num_temp_dn_groups, num_temp_dn = temp_dn_feat.shape[:3]
        temp_dn_id = self.dn_metas["dn_id_target"]

        # bs, num_temp_dn_groups, num_temp_dn, num_dn
        match = temp_dn_id[..., None] == dn_id[:, :num_temp_dn_groups, None]
        temp_reg_target = (
            match[..., None] * dn_reg_target[:, :num_temp_dn_groups, None]
        ).sum(dim=3)
        # Cached queries whose instance vanished get class -1 (invalid).
        temp_cls_target = torch.where(
            torch.all(torch.logical_not(match), dim=-1),
            self.dn_metas["dn_cls_target"].new_tensor(-1),
            self.dn_metas["dn_cls_target"],
        )
        temp_valid_mask = self.dn_metas["valid_mask"]
        temp_dn_anchor = self.dn_metas["dn_anchor"]

        # handle the misalignment the length of temp_dn to dn caused by the
        # change of num_gt, then concat the temp_dn and dn
        temp_dn_metas = [
            temp_dn_feat,
            temp_dn_anchor,
            temp_reg_target,
            temp_cls_target,
            temp_valid_mask,
            temp_dn_id,
        ]
        dn_metas = [
            dn_feat,
            dn_anchor,
            dn_reg_target,
            dn_cls_target,
            valid_mask,
            dn_id,
        ]
        output = []
        for i, (temp_meta, meta) in enumerate(zip(temp_dn_metas, dn_metas)):
            if num_temp_dn < num_dn:
                # Pad cached groups up to the current per-group length.
                pad = (0, num_dn - num_temp_dn)
                if temp_meta.dim() == 4:
                    pad = (0, 0) + pad
                else:
                    assert temp_meta.dim() == 3
                temp_meta = F.pad(temp_meta, pad, value=0)
            else:
                temp_meta = temp_meta[:, :, :num_dn]
            # Per-sample: use the cached groups only where the temporal
            # link is valid; otherwise fall back to the fresh ones.
            mask = temporal_valid_mask[:, None, None]
            if meta.dim() == 4:
                mask = mask.unsqueeze(dim=-1)
            temp_meta = torch.where(
                mask, temp_meta, meta[:, :num_temp_dn_groups]
            )
            meta = torch.cat([temp_meta, meta[:, num_temp_dn_groups:]], dim=1)
            meta = meta.flatten(1, 2)
            output.append(meta)
        output[0] = torch.cat([instance_feature, output[0]], dim=1)
        output[1] = torch.cat([anchor, output[1]], dim=1)
        return output

    def cache_dn(
        self,
        dn_instance_feature,
        dn_anchor,
        dn_cls_target,
        valid_mask,
        dn_id_target,
    ):
        """Cache a random subset of DN groups for the next frame.

        Selects ``num_temp_dn_groups`` groups at random and stores their
        detached features/anchors and targets in ``self.dn_metas``.
        """
        if self.num_temp_dn_groups < 0:
            return
        num_dn_groups = self.num_dn_groups
        bs, num_dn = dn_instance_feature.shape[:2]
        num_temp_dn = num_dn // num_dn_groups
        # Random boolean mask selecting which groups survive to next frame.
        temp_group_mask = (
            torch.randperm(num_dn_groups) < self.num_temp_dn_groups
        )
        temp_group_mask = temp_group_mask.to(device=dn_anchor.device)
        dn_instance_feature = dn_instance_feature.detach().reshape(
            bs, num_dn_groups, num_temp_dn, -1
        )[:, temp_group_mask]
        dn_anchor = dn_anchor.detach().reshape(
            bs, num_dn_groups, num_temp_dn, -1
        )[:, temp_group_mask]
        dn_cls_target = dn_cls_target.reshape(bs, num_dn_groups, num_temp_dn)[
            :, temp_group_mask
        ]
        valid_mask = valid_mask.reshape(bs, num_dn_groups, num_temp_dn)[
            :, temp_group_mask
        ]
        if dn_id_target is not None:
            dn_id_target = dn_id_target.reshape(
                bs, num_dn_groups, num_temp_dn
            )[:, temp_group_mask]
        self.dn_metas = dict(
            dn_instance_feature=dn_instance_feature,
            dn_anchor=dn_anchor,
            dn_cls_target=dn_cls_target,
            valid_mask=valid_mask,
            dn_id_target=dn_id_target,
        )
diff --git a/projects/mmdet3d_plugin/models/grid_mask.py b/projects/mmdet3d_plugin/models/grid_mask.py
new file mode 100644
index 0000000..ead0caa
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/grid_mask.py
@@ -0,0 +1,138 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from PIL import Image
+
+
class Grid(object):
    """GridMask augmentation for a single (C, H, W) image tensor.

    With probability ``prob``, zeroes a randomly rotated grid of stripes
    across the image (mode 0) or keeps only the stripes (mode 1);
    optionally fills the masked area with random offsets.

    Args:
        use_h: Draw horizontal stripes.
        use_w: Draw vertical stripes.
        rotate: Upper bound (exclusive) of the random rotation in degrees.
        offset: Fill masked pixels with random values instead of zeros.
        ratio: Stripe width as a fraction of the grid period.
        mode: 0 zeroes the stripes, 1 keeps only the stripes.
        prob: Maximum apply probability (ramped via set_prob).
    """

    def __init__(
        self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0
    ):
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        self.st_prob = prob
        self.prob = prob

    def set_prob(self, epoch, max_epoch):
        """Linearly ramp the apply probability with training progress."""
        self.prob = self.st_prob * epoch / max_epoch

    def __call__(self, img, label):
        """Apply the grid mask to ``img``; ``label`` passes through."""
        if np.random.rand() > self.prob:
            return img, label
        h = img.size(1)
        w = img.size(2)
        self.d1 = 2
        self.d2 = min(h, w)
        # Oversize the mask so a rotated crop still covers the image.
        hh = int(1.5 * h)
        ww = int(1.5 * w)
        d = np.random.randint(self.d1, self.d2)
        if self.ratio == 1:
            self.l = np.random.randint(1, d)
        else:
            self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
        mask = np.ones((hh, ww), np.float32)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                t = min(s + self.l, hh)
                mask[s:t, :] *= 0
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                t = min(s + self.l, ww)
                mask[:, s:t] *= 0

        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
        mask = mask[
            (hh - h) // 2 : (hh - h) // 2 + h,
            (ww - w) // 2 : (ww - w) // 2 + w,
        ]

        # Bug fix: .copy() — np.asarray on a PIL image can yield a
        # non-writable array, which torch.from_numpy rejects. This also
        # matches the sibling GridMask.forward, which already copies.
        mask = torch.from_numpy(mask.copy()).float()
        if self.mode == 1:
            mask = 1 - mask

        mask = mask.expand_as(img)
        if self.offset:
            offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float()
            offset = (1 - mask) * offset
            img = img * mask + offset
        else:
            img = img * mask

        return img, label
+
+
class GridMask(nn.Module):
    """Batched GridMask augmentation for (N, C, H, W) image tensors.

    During training, with probability ``prob``, zeroes a randomly rotated
    grid of stripes across every image in the batch (mode 0) or keeps only
    the stripes (mode 1); optionally fills masked pixels with random
    offsets. Identity at eval time.
    """

    def __init__(
        self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0
    ):
        super(GridMask, self).__init__()
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        self.st_prob = prob
        self.prob = prob

    def set_prob(self, epoch, max_epoch):
        """Linearly ramp the apply probability with training progress."""
        self.prob = self.st_prob * epoch / max_epoch  # + 1.#0.5

    def forward(self, x):
        # Identity at eval time or when the random draw skips augmentation.
        if np.random.rand() > self.prob or not self.training:
            return x
        n, c, h, w = x.size()
        x = x.view(-1, h, w)
        # Oversize the mask so a rotated crop still covers the image.
        hh = int(1.5 * h)
        ww = int(1.5 * w)
        d = np.random.randint(2, h)
        self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
        mask = np.ones((hh, ww), np.float32)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                t = min(s + self.l, hh)
                mask[s:t, :] *= 0
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                t = min(s + self.l, ww)
                mask[:, s:t] *= 0

        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
        mask = mask[
            (hh - h) // 2 : (hh - h) // 2 + h,
            (ww - w) // 2 : (ww - w) // 2 + w,
        ]

        # Bug fix: follow the input's device instead of hard-coded .cuda(),
        # so the module also works on CPU tensors (identical behavior for
        # CUDA inputs).
        mask = torch.from_numpy(mask.copy()).float().to(x.device)
        if self.mode == 1:
            mask = 1 - mask
        mask = mask.expand_as(x)
        if self.offset:
            offset = (
                torch.from_numpy(2 * (np.random.rand(h, w) - 0.5))
                .float()
                .to(x.device)
            )
            x = x * mask + offset * (1 - mask)
        else:
            x = x * mask

        return x.view(n, c, h, w)
diff --git a/projects/mmdet3d_plugin/models/gt_sparse_drive_head.py b/projects/mmdet3d_plugin/models/gt_sparse_drive_head.py
new file mode 100644
index 0000000..bdb80d9
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/gt_sparse_drive_head.py
@@ -0,0 +1,378 @@
+from typing import List
+
+import torch
+
+from mmcv.runner import BaseModule
+from mmdet.models import HEADS, build_head
+
+from projects.mmdet3d_plugin.core.box3d import SIN_YAW, COS_YAW, YAW
+
+
@HEADS.register_module()
class GTSparseDriveHead(BaseModule):
    """SparseDrive head that uses GT boxes and optionally GT map as oracle inputs.

    Bypasses the detection transformer and feeds GT bounding boxes (and
    optionally GT map polylines) directly into the motion/planning head
    for oracle prediction evaluation.

    The det_head config is still required to provide:
        - anchor_encoder (SparseBox3DEncoder)
        - instance_bank (InstanceBank for temporal mask and anchor_handler)
        - sampler (for the indices interface expected by MotionTarget)

    If map_head config is provided, its anchor_encoder and instance_bank
    are used to encode GT map polylines into the map_output structure
    expected by MotionPlanningHead (enables cross_gnn map attention).

    Args:
        task_config: Task flags (with_det, with_map, with_motion_plan).
        det_head: Config for Sparse4DHead (used for sub-modules only).
        map_head: Config for map Sparse4DHead (used for sub-modules only).
            If provided, GT map is injected into motion/planning head.
        motion_plan_head: Config for MotionPlanningHead.
        num_classes: Number of detection classes.
    """

    def __init__(
        self,
        task_config: dict,
        det_head: dict = None,
        map_head: dict = None,
        motion_plan_head: dict = None,
        num_classes: int = 10,
        num_map_classes: int = 3,
        init_cfg=None,
        **kwargs,
    ):
        super(GTSparseDriveHead, self).__init__(init_cfg)
        self.task_config = task_config
        self.num_classes = num_classes

        assert det_head is not None, (
            "det_head config is required to provide anchor_encoder "
            "and instance_bank sub-modules."
        )
        self.det_head = build_head(det_head)
        # Freeze all det_head parameters: we only use its sub-modules
        # (anchor_encoder, instance_bank) as non-trainable components.
        # This prevents DDP from complaining about unused parameters.
        for p in self.det_head.parameters():
            p.requires_grad_(False)

        self.num_map_classes = num_map_classes
        if map_head is not None:
            self.map_head = build_head(map_head)
            # Freeze map_head parameters: only anchor_encoder and
            # instance_bank sub-modules are used (non-trainable).
            for p in self.map_head.parameters():
                p.requires_grad_(False)

        assert motion_plan_head is not None
        self.motion_plan_head = build_head(motion_plan_head)

    def init_weights(self):
        """Initialise weights of all sub-heads (map_head only if present)."""
        self.det_head.init_weights()
        if hasattr(self, "map_head"):
            self.map_head.init_weights()
        self.motion_plan_head.init_weights()

    # ------------------------------------------------------------------ #
    # Forward
    # ------------------------------------------------------------------ #

    def forward(self, feature_maps, metas: dict):
        """Run the oracle pipeline: GT det/map outputs -> motion/planning.

        Returns:
            Tuple of (det_output, map_output, motion_output,
            planning_output); map_output is None when no map_head is
            configured.
        """
        batch_size = len(metas["img_metas"])
        device = (
            feature_maps[0].device
            if isinstance(feature_maps[0], torch.Tensor)
            else feature_maps[0][0].device
        )

        # 1. Update instance_bank temporal mask.
        # We discard the returned bank features; we only need self.mask.
        self.det_head.instance_bank.get(batch_size, metas)

        # 2. Build det_output from GT boxes.
        det_output = self._build_gt_det_output(metas, batch_size, feature_maps)

        # 3. Cache GT features/anchors in the bank for temporal tracking.
        self.det_head.instance_bank.cache(
            det_output["instance_feature"],
            det_output["prediction"][-1],
            det_output["classification"][-1],
            metas,
            feature_maps,
        )

        # 4. Build GT map_output if map_head sub-modules are available.
        if hasattr(self, "map_head"):
            map_output = self._build_gt_map_output(metas, batch_size, device)
        else:
            map_output = None

        # 5. Forward motion/planning head.
        motion_output, planning_output = self.motion_plan_head(
            det_output,
            map_output,
            feature_maps,
            metas,
            self.det_head.anchor_encoder,
            self.det_head.instance_bank.mask,
            self.det_head.instance_bank.anchor_handler,
        )

        return det_output, map_output, motion_output, planning_output

    # ------------------------------------------------------------------ #
    # GT det_output construction helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _encode_gt_boxes(boxes: torch.Tensor) -> torch.Tensor:
        """Convert decoded 9-dim GT boxes to encoded 11-dim anchor format.

        Input: (N, >=7) [..., x, y, z, w, l, h, yaw, vx, vy]
        Output: (N, 11) [x, y, z, log_w, log_l, log_h,
                         sin_yaw, cos_yaw, vx, vy, 0]
        """
        xyz = boxes[:, :3]
        # clamp avoids log(0) for degenerate GT extents.
        wlh = boxes[:, 3:6].clamp(min=1e-3).log()
        sin_yaw = torch.sin(boxes[:, YAW : YAW + 1])
        cos_yaw = torch.cos(boxes[:, YAW : YAW + 1])
        # Velocity defaults to zero when the GT box has no vx/vy columns.
        vel = boxes[:, 7:9] if boxes.shape[-1] >= 9 else boxes.new_zeros(len(boxes), 2)
        vz = boxes.new_zeros(len(boxes), 1)
        return torch.cat([xyz, wlh, sin_yaw, cos_yaw, vel, vz], dim=-1)

    def _build_gt_det_output(
        self, metas: dict, batch_size: int, feature_maps
    ) -> dict:
        """Build the det_output dict populated with GT boxes."""
        if isinstance(feature_maps[0], torch.Tensor):
            device = feature_maps[0].device
        else:
            device = feature_maps[0][0].device

        num_anchor = self.det_head.instance_bank.num_anchor
        embed_dims = self.det_head.instance_bank.embed_dims

        gt_bboxes = metas["gt_bboxes_3d"]  # list[Tensor(N_i, 9)]
        gt_labels = metas["gt_labels_3d"]  # list[Tensor(N_i,)]

        anchors = torch.zeros(batch_size, num_anchor, 11, device=device)
        # Very-negative logits → near-zero confidence after sigmoid (padding).
        cls_logits = anchors.new_full(
            (batch_size, num_anchor, self.num_classes), -100.0
        )

        for i in range(batch_size):
            bboxes_i = gt_bboxes[i]
            labels_i = gt_labels[i]

            if not isinstance(bboxes_i, torch.Tensor):
                bboxes_i = torch.tensor(
                    bboxes_i, device=device, dtype=torch.float32
                )
            else:
                bboxes_i = bboxes_i.to(device=device, dtype=torch.float32)

            if not isinstance(labels_i, torch.Tensor):
                labels_i = torch.tensor(
                    labels_i, device=device, dtype=torch.long
                )
            else:
                labels_i = labels_i.to(device=device)

            N_i = len(bboxes_i)
            if N_i == 0:
                continue
            # GT beyond the anchor budget is silently truncated.
            N_i = min(N_i, num_anchor)
            bboxes_i = bboxes_i[:N_i]
            labels_i = labels_i[:N_i]

            anchors[i, :N_i] = self._encode_gt_boxes(bboxes_i)

            # High-positive logit at GT class, very-negative elsewhere.
            cls_logits[i, :N_i] = -100.0
            cls_logits[i, torch.arange(N_i, device=device), labels_i] = 100.0

        # Anchor embeddings from the det_head's encoder.
        anchor_embed = self.det_head.anchor_encoder(anchors)

        # Instance features initialised to zero; the motion GNN refines them.
        instance_feature = torch.zeros(
            batch_size, num_anchor, embed_dims, device=device
        )

        # GT instance IDs for temporal tracking in InstanceQueue.
        instance_id = self._get_gt_instance_ids(
            metas, batch_size, num_anchor, device
        )

        return {
            "instance_feature": instance_feature,
            "anchor_embed": anchor_embed,
            "classification": [cls_logits],
            "prediction": [anchors],
            "quality": [None],
            "instance_id": instance_id,
        }

    def _build_gt_map_output(
        self, metas: dict, batch_size: int, device
    ) -> dict:
        """Build the map_output dict populated with GT map polylines.

        GT map pts are encoded via the map_head's anchor_encoder into the
        same anchor_embed space that MotionPlanningHead's cross_gnn expects.
        Classification logits are set to +100 at the GT class so that the
        confidence-based top-k selection in MotionPlanningHead picks the
        true map elements.

        Args:
            metas: Batch metas containing 'gt_map_labels' and 'gt_map_pts'.
                gt_map_labels: list[Tensor(M_i,)]
                gt_map_pts: list[Tensor(M_i, num_sample, 2)] (test)
                    or list[Tensor(M_i, num_perms, num_sample, 2)]
                    (train, when VectorizeMap has permute=True).
        """
        num_anchor = self.map_head.instance_bank.num_anchor  # 100
        embed_dims = self.map_head.instance_bank.embed_dims  # 256
        num_sample_x2 = self.map_head.anchor_encoder.input_dims  # num_sample * 2

        gt_map_labels = metas["gt_map_labels"]  # list[Tensor(M_i,)]
        gt_map_pts = metas["gt_map_pts"]  # list[Tensor(...)]

        predictions = torch.zeros(
            batch_size, num_anchor, num_sample_x2, device=device
        )
        cls_logits = torch.full(
            (batch_size, num_anchor, self.num_map_classes), -100.0, device=device
        )

        for i in range(batch_size):
            map_pts_i = gt_map_pts[i]
            map_labels_i = gt_map_labels[i]

            if not isinstance(map_pts_i, torch.Tensor):
                map_pts_i = torch.tensor(
                    map_pts_i, device=device, dtype=torch.float32
                )
            else:
                map_pts_i = map_pts_i.to(device=device, dtype=torch.float32)

            if not isinstance(map_labels_i, torch.Tensor):
                map_labels_i = torch.tensor(
                    map_labels_i, device=device, dtype=torch.long
                )
            else:
                map_labels_i = map_labels_i.to(device=device)

            M_i = len(map_pts_i)
            if M_i == 0:
                continue
            M_i = min(M_i, num_anchor)
            map_pts_i = map_pts_i[:M_i]
            map_labels_i = map_labels_i[:M_i]

            # Train pipeline uses permute=True: (M, num_perms, num_sample, 2).
            # Take the first permutation (canonical polyline direction).
            if map_pts_i.dim() == 4:
                map_pts_i = map_pts_i[:, 0]  # (M, num_sample, 2)

            predictions[i, :M_i] = map_pts_i.reshape(M_i, -1)
            cls_logits[i, :M_i] = -100.0
            cls_logits[i, torch.arange(M_i, device=device), map_labels_i] = 100.0

        # Encode GT map point coordinates into positional embeddings.
        anchor_embed = self.map_head.anchor_encoder(predictions)

        # Instance features initialised to zero; cross_gnn will attend to
        # anchor_embed (position) rather than learned instance content.
        instance_feature = torch.zeros(
            batch_size, num_anchor, embed_dims, device=device
        )

        return {
            "instance_feature": instance_feature,
            "anchor_embed": anchor_embed,
            "classification": [cls_logits],
            "prediction": [predictions],
        }

    @staticmethod
    def _get_gt_instance_ids(
        metas: dict,
        batch_size: int,
        num_anchor: int,
        device,
    ) -> torch.Tensor:
        """Read GT instance IDs from metas, padded to (bs, num_anchor)."""
        instance_id = torch.full(
            (batch_size, num_anchor), -1, dtype=torch.long, device=device
        )
        for i, img_meta in enumerate(metas["img_metas"]):
            gt_id = img_meta.get("instance_id", None)
            if gt_id is None:
                continue
            if not isinstance(gt_id, torch.Tensor):
                gt_id = torch.tensor(gt_id, dtype=torch.long, device=device)
            else:
                gt_id = gt_id.to(device=device)
            N_i = min(len(gt_id), num_anchor)
            instance_id[i, :N_i] = gt_id[:N_i]
        return instance_id

    # ------------------------------------------------------------------ #
    # Loss
    # ------------------------------------------------------------------ #

    def loss(self, model_outs, data):
        """Compute the motion/planning loss only (det/map are GT oracles)."""
        _, _, motion_output, planning_output = model_outs

        # The motion head's target assigner normally reads the detection
        # matcher's indices; with GT-as-detection the match is identity.
        motion_loss_cache = dict(
            indices=self._build_identity_indices(data),
        )
        return self.motion_plan_head.loss(
            motion_output, planning_output, data, motion_loss_cache
        )

    def _build_identity_indices(self, data) -> List:
        """Identity matching: pred anchor i → GT agent i for each batch item.

        The motion sampler uses these indices to assign GT future trajectories
        to predicted agent slots. With GT-as-detection, agent i directly
        corresponds to GT agent i, so both pred_idx and target_idx are [0..N_i).
        """
        gt_labels = data["gt_labels_3d"]  # list of (N_i,) tensors
        device = next(self.parameters()).device
        indices = []
        for i in range(len(gt_labels)):
            N_i = len(gt_labels[i])
            if N_i == 0:
                indices.append([None, None])
            else:
                idx = torch.arange(N_i, device=device, dtype=torch.long)
                indices.append([idx, idx])
        return indices

    # ------------------------------------------------------------------ #
    # Post-process
    # ------------------------------------------------------------------ #

    def post_process(self, model_outs, data):
        """Decode det/motion/planning outputs and merge per-sample results."""
        det_output, _, motion_output, planning_output = model_outs

        det_result = self.det_head.post_process(det_output)
        motion_result, planning_result = self.motion_plan_head.post_process(
            det_output, motion_output, planning_output, data
        )

        batch_size = len(motion_result)
        results = [dict() for _ in range(batch_size)]
        for i in range(batch_size):
            results[i].update(det_result[i])
            results[i].update(motion_result[i])
            results[i].update(planning_result[i])

        return results
diff --git a/projects/mmdet3d_plugin/models/instance_bank.py b/projects/mmdet3d_plugin/models/instance_bank.py
new file mode 100644
index 0000000..49a0353
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/instance_bank.py
@@ -0,0 +1,265 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+import numpy as np
+
+from mmcv.utils import build_from_cfg
+from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
+
+__all__ = ["InstanceBank"]
+
+
def topk(confidence, k, *inputs):
    """Select the top-k instances per batch item by confidence.

    Args:
        confidence: (bs, N) scores used for ranking.
        k: number of instances to keep per batch item.
        *inputs: tensors of shape (bs, N, ...) gathered with the same
            top-k indices; each is returned reshaped to (bs, k, -1).

    Returns:
        (top_confidence, [gathered_inputs...]).
    """
    bs, num_inst = confidence.shape[:2]
    top_conf, top_idx = torch.topk(confidence, k, dim=1)
    # Convert per-batch indices into flat indices over the (bs*N) dim.
    offsets = torch.arange(bs, device=top_idx.device)[:, None] * num_inst
    flat_idx = (top_idx + offsets).reshape(-1)
    gathered = [
        tensor.flatten(end_dim=1)[flat_idx].reshape(bs, k, -1)
        for tensor in inputs
    ]
    return top_conf, gathered
+
+
@PLUGIN_LAYERS.register_module()
class InstanceBank(nn.Module):
    """Bank of learnable instance anchors/features with a temporal cache.

    Holds a learnable anchor set and instance-feature set, and keeps the
    top-``num_temp_instances`` instances from the previous frame so they can
    be propagated into the current frame (via the optional
    ``anchor_handler.anchor_projection``) and fused with fresh instances.
    """

    def __init__(
        self,
        num_anchor,
        embed_dims,
        anchor,
        anchor_handler=None,
        num_temp_instances=0,
        default_time_interval=0.5,
        confidence_decay=0.6,
        anchor_grad=True,
        feat_grad=True,
        max_time_interval=2,
    ):
        # num_anchor: number of instance slots per frame.
        # anchor: initial anchors — a .npy path, list/tuple, or ndarray.
        # anchor_handler: cfg for a module exposing ``anchor_projection``,
        #   used to warp cached anchors across frames.
        # num_temp_instances: how many previous-frame instances to cache.
        # confidence_decay: multiplicative decay applied to cached scores.
        # max_time_interval: cached state older than this is masked out
        #   (same units as metas["timestamp"] — presumably seconds; confirm).
        super(InstanceBank, self).__init__()
        self.embed_dims = embed_dims
        self.num_temp_instances = num_temp_instances
        self.default_time_interval = default_time_interval
        self.confidence_decay = confidence_decay
        self.max_time_interval = max_time_interval

        if anchor_handler is not None:
            anchor_handler = build_from_cfg(anchor_handler, PLUGIN_LAYERS)
            # The handler must know how to project anchors between frames.
            assert hasattr(anchor_handler, "anchor_projection")
        self.anchor_handler = anchor_handler
        if isinstance(anchor, str):
            anchor = np.load(anchor)
        elif isinstance(anchor, (list, tuple)):
            anchor = np.array(anchor)
        # Map anchors arrive as (num_anchor, num_pts, coords); flatten points.
        if len(anchor.shape) == 3:  # for map
            anchor = anchor.reshape(anchor.shape[0], -1)
        self.num_anchor = min(len(anchor), num_anchor)
        anchor = anchor[:num_anchor]
        self.anchor = nn.Parameter(
            torch.tensor(anchor, dtype=torch.float32),
            requires_grad=anchor_grad,
        )
        # Kept so init_weight() can restore the initial anchor values.
        self.anchor_init = anchor
        self.instance_feature = nn.Parameter(
            torch.zeros([self.anchor.shape[0], self.embed_dims]),
            requires_grad=feat_grad,
        )
        self.reset()

    def init_weight(self):
        """Re-initialize anchors to their configured values and features
        (when trainable) with Xavier-uniform."""
        self.anchor.data = self.anchor.data.new_tensor(self.anchor_init)
        if self.instance_feature.requires_grad:
            torch.nn.init.xavier_uniform_(self.instance_feature.data, gain=1)

    def reset(self):
        """Drop all temporal state (call at sequence boundaries)."""
        self.cached_feature = None
        self.cached_anchor = None
        self.metas = None
        self.mask = None
        self.confidence = None
        self.temp_confidence = None
        self.instance_id = None
        self.prev_id = 0

    def get(self, batch_size, metas=None, dn_metas=None):
        """Return current-frame instances plus projected cached state.

        Returns:
            (instance_feature, anchor, cached_feature, cached_anchor,
            time_interval) — cached entries are None on the first frame or
            after a batch-size change (which triggers ``reset``).
        """
        instance_feature = torch.tile(
            self.instance_feature[None], (batch_size, 1, 1)
        )
        anchor = torch.tile(self.anchor[None], (batch_size, 1, 1))

        if (
            self.cached_anchor is not None
            and batch_size == self.cached_anchor.shape[0]
        ):
            history_time = self.metas["timestamp"]
            time_interval = metas["timestamp"] - history_time
            time_interval = time_interval.to(dtype=instance_feature.dtype)
            # Per-sample validity of the cache: too-old frames are masked.
            self.mask = torch.abs(time_interval) <= self.max_time_interval

            if self.anchor_handler is not None:
                # Transform from the cached frame's global pose into the
                # current frame (T_temp2cur = inv(T_global_cur) @ T_global_temp).
                T_temp2cur = self.cached_anchor.new_tensor(
                    np.stack(
                        [
                            x["T_global_inv"]
                            @ self.metas["img_metas"][i]["T_global"]
                            for i, x in enumerate(metas["img_metas"])
                        ]
                    )
                )
                self.cached_anchor = self.anchor_handler.anchor_projection(
                    self.cached_anchor,
                    [T_temp2cur],
                    time_intervals=[-time_interval],
                )[0]

            # Also project denoising anchors, when provided with a matching
            # batch size.
            if (
                self.anchor_handler is not None
                and dn_metas is not None
                and batch_size == dn_metas["dn_anchor"].shape[0]
            ):
                num_dn_group, num_dn = dn_metas["dn_anchor"].shape[1:3]
                dn_anchor = self.anchor_handler.anchor_projection(
                    dn_metas["dn_anchor"].flatten(1, 2),
                    [T_temp2cur],
                    time_intervals=[-time_interval],
                )[0]
                dn_metas["dn_anchor"] = dn_anchor.reshape(
                    batch_size, num_dn_group, num_dn, -1
                )
            # Zero or invalid intervals fall back to the default so
            # downstream velocity integration stays well-defined.
            time_interval = torch.where(
                torch.logical_and(time_interval != 0, self.mask),
                time_interval,
                time_interval.new_tensor(self.default_time_interval),
            )
        else:
            self.reset()
            time_interval = instance_feature.new_tensor(
                [self.default_time_interval] * batch_size
            )

        return (
            instance_feature,
            anchor,
            self.cached_feature,
            self.cached_anchor,
            time_interval,
        )

    def update(self, instance_feature, anchor, confidence,
               cached_feature_override=None, cached_anchor_override=None):
        """Fuse cached instances into the current set.

        Keeps the cached (temporal) instances and the top
        ``num_anchor - num_temp_instances`` fresh instances by confidence;
        samples whose cache is stale (``self.mask`` False) keep only fresh
        instances. Denoising instances appended past ``num_anchor`` are
        split off and re-attached unchanged.
        """
        if self.cached_feature is None:
            return instance_feature, anchor

        num_dn = 0
        if instance_feature.shape[1] > self.num_anchor:
            num_dn = instance_feature.shape[1] - self.num_anchor
            dn_instance_feature = instance_feature[:, -num_dn:]
            dn_anchor = anchor[:, -num_dn:]
            instance_feature = instance_feature[:, : self.num_anchor]
            anchor = anchor[:, : self.num_anchor]
            confidence = confidence[:, : self.num_anchor]

        # Optional overrides let a caller substitute its own cached state.
        cached_feat = cached_feature_override if cached_feature_override is not None \
            else self.cached_feature
        cached_anch = cached_anchor_override if cached_anchor_override is not None \
            else self.cached_anchor

        N = self.num_anchor - self.num_temp_instances
        confidence = confidence.max(dim=-1).values
        _, (selected_feature, selected_anchor) = topk(
            confidence, N, instance_feature, anchor
        )
        selected_feature = torch.cat(
            [cached_feat, selected_feature], dim=1
        )
        selected_anchor = torch.cat(
            [cached_anch, selected_anchor], dim=1
        )
        # Where the cache is valid, use [cached + top fresh]; otherwise keep
        # the untouched fresh set.
        instance_feature = torch.where(
            self.mask[:, None, None], selected_feature, instance_feature
        )
        anchor = torch.where(self.mask[:, None, None], selected_anchor, anchor)
        # Stale samples get their cached confidences / ids cleared.
        self.confidence = torch.where(
            self.mask[:, None],
            self.confidence,
            self.confidence.new_tensor(0)
        )
        if self.instance_id is not None:
            self.instance_id = torch.where(
                self.mask[:, None],
                self.instance_id,
                self.instance_id.new_tensor(-1),
            )

        if num_dn > 0:
            instance_feature = torch.cat(
                [instance_feature, dn_instance_feature], dim=1
            )
            anchor = torch.cat([anchor, dn_anchor], dim=1)
        return instance_feature, anchor

    def cache(
        self,
        instance_feature,
        anchor,
        confidence,
        metas=None,
        feature_maps=None,
    ):
        """Store the top-``num_temp_instances`` instances for the next frame.

        Confidences of instances carried over from the previous frame decay
        by ``confidence_decay`` but never below their fresh score.
        """
        if self.num_temp_instances <= 0:
            return
        # Detach: the cache must not keep the autograd graph alive.
        instance_feature = instance_feature.detach()
        anchor = anchor.detach()
        confidence = confidence.detach()

        self.metas = metas
        confidence = confidence.max(dim=-1).values.sigmoid()
        if self.confidence is not None:
            confidence[:, : self.num_temp_instances] = torch.maximum(
                self.confidence * self.confidence_decay,
                confidence[:, : self.num_temp_instances],
            )
        self.temp_confidence = confidence

        (
            self.confidence,
            (self.cached_feature, self.cached_anchor),
        ) = topk(confidence, self.num_temp_instances, instance_feature, anchor)

    def get_instance_id(self, confidence, anchor=None, threshold=None):
        """Assign persistent track ids; new ids for unmatched slots.

        Slots inherited from the cache keep their id; remaining slots
        (optionally only those above ``threshold``) receive fresh,
        monotonically increasing ids from ``self.prev_id``.
        """
        confidence = confidence.max(dim=-1).values.sigmoid()
        instance_id = confidence.new_full(confidence.shape, -1).long()

        if (
            self.instance_id is not None
            and self.instance_id.shape[0] == instance_id.shape[0]
        ):
            instance_id[:, : self.instance_id.shape[1]] = self.instance_id

        mask = instance_id < 0
        if threshold is not None:
            mask = mask & (confidence >= threshold)
        num_new_instance = mask.sum()
        new_ids = torch.arange(num_new_instance).to(instance_id) + self.prev_id
        instance_id[torch.where(mask)] = new_ids
        # NOTE(review): num_new_instance is a 0-d tensor, so prev_id becomes
        # a tensor after the first call — harmless here but worth confirming.
        self.prev_id += num_new_instance
        self.update_instance_id(instance_id, confidence)
        return instance_id

    def update_instance_id(self, instance_id=None, confidence=None):
        """Cache the ids of the top-``num_temp_instances`` instances,
        padding the remaining slots with -1."""
        if self.temp_confidence is None:
            if confidence.dim() == 3:  # bs, num_anchor, num_cls
                temp_conf = confidence.max(dim=-1).values
            else:  # bs, num_anchor
                temp_conf = confidence
        else:
            temp_conf = self.temp_confidence
        # topk returns (bs, k, 1); squeeze the trailing singleton dim.
        instance_id = topk(temp_conf, self.num_temp_instances, instance_id)[1][
            0
        ]
        instance_id = instance_id.squeeze(dim=-1)
        self.instance_id = F.pad(
            instance_id,
            (0, self.num_anchor - self.num_temp_instances),
            value=-1,
        )
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/models/map/__init__.py b/projects/mmdet3d_plugin/models/map/__init__.py
new file mode 100644
index 0000000..4fa1d7d
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/map/__init__.py
@@ -0,0 +1,9 @@
+from .decoder import SparsePoint3DDecoder
+from .target import SparsePoint3DTarget, HungarianLinesAssigner
+from .match_cost import LinesL1Cost, MapQueriesCost
+from .loss import LinesL1Loss, SparseLineLoss
+from .map_blocks import (
+ SparsePoint3DRefinementModule,
+ SparsePoint3DKeyPointsGenerator,
+ SparsePoint3DEncoder,
+)
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/models/map/decoder.py b/projects/mmdet3d_plugin/models/map/decoder.py
new file mode 100644
index 0000000..b9564c9
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/map/decoder.py
@@ -0,0 +1,53 @@
+from typing import Optional, List
+
+import torch
+
+from mmdet.core.bbox.builder import BBOX_CODERS
+
+
@BBOX_CODERS.register_module()
class SparsePoint3DDecoder(object):
    """Decode sparse map-point predictions into per-sample vector results.

    Takes the last decoder stage's classification scores and point
    regressions and emits, per batch item, numpy polylines with their
    scores and labels, optionally filtered by ``score_threshold``.
    """

    def __init__(
        self,
        coords_dim: int = 2,
        score_threshold: Optional[float] = None,
    ):
        super(SparsePoint3DDecoder, self).__init__()
        self.coords_dim = coords_dim
        self.score_threshold = score_threshold

    def decode(
        self,
        cls_scores,
        pts_preds,
        instance_id=None,
        quality=None,
        output_idx=-1,
    ):
        """Decode the final stage of ``cls_scores``/``pts_preds``."""
        final_scores = cls_scores[-1].sigmoid()
        bs, num_pred, num_cls = final_scores.shape
        final_pts = pts_preds[-1].reshape(bs, num_pred, -1, self.coords_dim)

        # Rank all (pred, class) pairs jointly and keep num_pred of them.
        top_scores, top_idx = final_scores.flatten(start_dim=1).topk(
            num_pred, dim=1
        )
        labels = top_idx % num_cls
        keep = None
        if self.score_threshold is not None:
            keep = top_scores >= self.score_threshold

        results = []
        for b in range(bs):
            labels_b = labels[b]
            scores_b = top_scores[b]
            pts_b = final_pts[b, top_idx[b] // num_cls]
            if keep is not None:
                labels_b = labels_b[keep[b]]
                scores_b = scores_b[keep[b]]
                pts_b = pts_b[keep[b]]

            results.append(
                {
                    "vectors": [vec.detach().cpu().numpy() for vec in pts_b],
                    "scores": scores_b.detach().cpu().numpy(),
                    "labels": labels_b.detach().cpu().numpy(),
                }
            )
        return results
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/models/map/loss.py b/projects/mmdet3d_plugin/models/map/loss.py
new file mode 100644
index 0000000..0edd8be
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/map/loss.py
@@ -0,0 +1,120 @@
+import torch
+import torch.nn as nn
+
+from mmcv.utils import build_from_cfg
+from mmdet.models.builder import LOSSES
+from mmdet.models.losses import l1_loss, smooth_l1_loss
+
+
@LOSSES.register_module()
class LinesL1Loss(nn.Module):
    """Point-set regression loss for polylines.

    Uses smooth-L1 when ``beta > 0`` and plain L1 otherwise, then divides
    by the number of points (last dim holds ``num_points`` (x, y) pairs)
    and scales by ``loss_weight``.

    Args:
        reduction (str): "none", "mean" or "sum".
        loss_weight (float): Final scale applied to the loss.
        beta (float): Smooth-L1 transition point; <= 0 selects plain L1.
    """

    def __init__(self, reduction='mean', loss_weight=1.0, beta=0.5):
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.beta = beta

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Compute the weighted line loss.

        Args:
            pred (Tensor): Predictions, shape [bs, ..., 2*num_points].
            target (Tensor): Targets with the same shape as ``pred``.
            weight (Tensor, optional): Per-element loss weights (useful
                when not all predictions are valid).
            avg_factor (int, optional): Normalizer for averaging.
            reduction_override (str, optional): Overrides ``self.reduction``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = reduction_override or self.reduction

        if self.beta > 0:
            raw = smooth_l1_loss(
                pred, target, weight,
                reduction=reduction, avg_factor=avg_factor, beta=self.beta,
            )
        else:
            raw = l1_loss(
                pred, target, weight,
                reduction=reduction, avg_factor=avg_factor,
            )

        # Normalize by the number of (x, y) points per line.
        raw = raw / (pred.shape[-1] // 2)
        return raw * self.loss_weight
+
+
@LOSSES.register_module()
class SparseLineLoss(nn.Module):
    """Wraps a line regression loss with ROI normalization.

    Both prediction and target polylines are shifted/scaled from ROI
    coordinates into roughly (0, 1) before the wrapped loss is applied.
    """

    def __init__(
        self,
        loss_line,
        num_sample=20,
        roi_size=(30, 60),
    ):
        super().__init__()
        self.loss_line = (
            build_from_cfg(loss_line, LOSSES) if loss_line is not None else None
        )
        self.num_sample = num_sample
        self.roi_size = roi_size

    def forward(
        self,
        line,
        line_target,
        weight=None,
        avg_factor=None,
        prefix="",
        suffix="",
        **kwargs,
    ):
        """Return ``{f"{prefix}loss_line{suffix}": loss}``."""
        pred_norm = self.normalize_line(line)
        target_norm = self.normalize_line(line_target)
        value = self.loss_line(
            pred_norm, target_norm, weight=weight, avg_factor=avg_factor
        )
        return {f"{prefix}loss_line{suffix}": value}

    def normalize_line(self, line):
        """Map ROI-frame coordinates into (0, 1) per axis."""
        if line.shape[0] == 0:
            return line

        pts = line.view(line.shape[:-1] + (self.num_sample, -1))

        # Shift the ROI's bottom-left corner to the origin.
        half = pts.new_tensor([self.roi_size[0] / 2, self.roi_size[1] / 2])
        pts = pts + half

        # Scale by the ROI extent; eps keeps boundary points strictly < 1.
        eps = 1e-5
        extent = pts.new_tensor([self.roi_size[0], self.roi_size[1]]) + eps
        pts = pts / extent

        return pts.flatten(-2, -1)
diff --git a/projects/mmdet3d_plugin/models/map/map_blocks.py b/projects/mmdet3d_plugin/models/map/map_blocks.py
new file mode 100644
index 0000000..1fb87bd
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/map/map_blocks.py
@@ -0,0 +1,199 @@
+from typing import Optional, List, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from mmcv.cnn import Linear, Scale, bias_init_with_prob
+from mmcv.runner.base_module import Sequential, BaseModule
+from mmcv.cnn import xavier_init
+from mmcv.cnn.bricks.registry import (
+ PLUGIN_LAYERS,
+ POSITIONAL_ENCODING,
+)
+
+from ..blocks import linear_relu_ln
+
+
@POSITIONAL_ENCODING.register_module()
class SparsePoint3DEncoder(BaseModule):
    """Embed a flattened polyline anchor into the instance feature space.

    A small MLP maps the ``num_sample * coords_dim`` anchor vector to
    ``embed_dims``.
    """

    def __init__(
        self,
        embed_dims: int = 256,
        num_sample: int = 20,
        coords_dim: int = 2,
    ):
        super(SparsePoint3DEncoder, self).__init__()
        self.embed_dims = embed_dims
        self.input_dims = num_sample * coords_dim
        self.pos_fc = nn.Sequential(
            *linear_relu_ln(embed_dims, 1, 2, self.input_dims)
        )

    def forward(self, anchor: torch.Tensor):
        """Return the positional embedding for ``anchor``."""
        return self.pos_fc(anchor)
+
+
@PLUGIN_LAYERS.register_module()
class SparsePoint3DRefinementModule(BaseModule):
    """Refine polyline anchors and optionally classify them.

    A regression MLP predicts per-point offsets added to the anchor; an
    optional classification MLP scores the instance feature.
    """

    def __init__(
        self,
        embed_dims: int = 256,
        num_sample: int = 20,
        coords_dim: int = 2,
        num_cls: int = 3,
        with_cls_branch: bool = True,
    ):
        super(SparsePoint3DRefinementModule, self).__init__()
        self.embed_dims = embed_dims
        self.num_sample = num_sample
        self.output_dim = num_sample * coords_dim
        self.num_cls = num_cls

        # Regression branch: offsets for every sampled point.
        self.layers = nn.Sequential(
            *linear_relu_ln(embed_dims, 2, 2),
            Linear(self.embed_dims, self.output_dim),
            Scale([1.0] * self.output_dim),
        )

        self.with_cls_branch = with_cls_branch
        if with_cls_branch:
            self.cls_layers = nn.Sequential(
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(self.embed_dims, self.num_cls),
            )

    def init_weight(self):
        """Bias-init the classifier for a low initial foreground prob."""
        if not self.with_cls_branch:
            return
        nn.init.constant_(self.cls_layers[-1].bias, bias_init_with_prob(0.01))

    def forward(
        self,
        instance_feature: torch.Tensor,
        anchor: torch.Tensor,
        anchor_embed: torch.Tensor,
        time_interval: torch.Tensor = 1.0,
        return_cls=True,
    ):
        """Return (refined_anchor, cls_logits_or_None, None, None)."""
        refined = anchor + self.layers(instance_feature + anchor_embed)
        cls = None
        if return_cls:
            assert self.with_cls_branch, "Without classification layers !!!"
            # Classification looks only at the instance feature (no anchor
            # embedding), matching the original behavior.
            cls = self.cls_layers(instance_feature)
        return refined, cls, None, None
+
+
@PLUGIN_LAYERS.register_module()
class SparsePoint3DKeyPointsGenerator(BaseModule):
    """Generate 3D key points around polyline anchors.

    Each anchor's 2D sample points are expanded with learnable 2D offsets
    and a set of fixed heights, producing
    ``num_sample * len(fix_height) * num_learnable_pts`` 3D key points per
    anchor. Also provides 2D anchor projection between frames.
    """

    def __init__(
        self,
        embed_dims: int = 256,
        num_sample: int = 20,
        num_learnable_pts: int = 0,
        fix_height: Tuple = (0,),
        ground_height: int = 0,
    ):
        super(SparsePoint3DKeyPointsGenerator, self).__init__()
        self.embed_dims = embed_dims
        self.num_sample = num_sample
        self.num_learnable_pts = num_learnable_pts
        # Total key points per anchor after offset/height expansion.
        self.num_pts = num_sample * len(fix_height) * num_learnable_pts
        if self.num_learnable_pts > 0:
            # Predicts one (dx, dy) per generated key point.
            self.learnable_fc = Linear(self.embed_dims, self.num_pts * 2)

        self.fix_height = np.array(fix_height)
        self.ground_height = ground_height

    def init_weight(self):
        if self.num_learnable_pts > 0:
            xavier_init(self.learnable_fc, distribution="uniform", bias=0.0)

    def forward(
        self,
        anchor,
        instance_feature=None,
        T_cur2temp_list=None,
        cur_timestamp=None,
        temp_timestamps=None,
    ):
        """Return (bs, num_anchor, num_pts, 3) key points; when temporal
        transforms are given, also a list of the same points projected into
        each temporal frame.
        """
        assert self.num_learnable_pts > 0, 'No learnable pts'
        bs, num_anchor, _ = anchor.shape
        # (bs, num_anchor, num_sample, 2): anchors store flattened 2D points.
        key_points = anchor.view(bs, num_anchor, self.num_sample, -1)
        offset = (
            self.learnable_fc(instance_feature)
            .reshape(bs, num_anchor, self.num_sample, len(self.fix_height), self.num_learnable_pts, 2)
        )
        # Broadcast each sample point over (fix_height, learnable_pts).
        key_points = offset + key_points[..., None, None, :]
        # Append a z coordinate initialized to the ground height.
        key_points = torch.cat(
            [
                key_points,
                key_points.new_full(key_points.shape[:-1]+(1,), fill_value=self.ground_height),
            ],
            dim=-1,
        )
        # Add each fixed height to the z channel (x/y offsets are zero).
        fix_height = key_points.new_tensor(self.fix_height)
        height_offset = key_points.new_zeros([len(fix_height), 2])
        height_offset = torch.cat([height_offset, fix_height[:,None]], dim=-1)
        key_points = key_points + height_offset[None, None, None, :, None]
        # Collapse (num_sample, fix_height, learnable_pts) into one axis.
        key_points = key_points.flatten(2, 4)
        if (
            cur_timestamp is None
            or temp_timestamps is None
            or T_cur2temp_list is None
            or len(temp_timestamps) == 0
        ):
            return key_points

        # Project key points into each temporal frame via homogeneous
        # coordinates: p_temp = T_cur2temp[:3] @ [p; 1].
        temp_key_points_list = []
        for i, t_time in enumerate(temp_timestamps):
            temp_key_points = key_points
            T_cur2temp = T_cur2temp_list[i].to(dtype=key_points.dtype)
            temp_key_points = (
                T_cur2temp[:, None, None, :3]
                @ torch.cat(
                    [
                        temp_key_points,
                        torch.ones_like(temp_key_points[..., :1]),
                    ],
                    dim=-1,
                ).unsqueeze(-1)
            )
            temp_key_points = temp_key_points.squeeze(-1)
            temp_key_points_list.append(temp_key_points)
        return key_points, temp_key_points_list

    # @staticmethod
    def anchor_projection(
        self,
        anchor,
        T_src2dst_list,
        src_timestamp=None,
        dst_timestamps=None,
        time_intervals=None,
    ):
        """Project 2D polyline anchors with each src→dst transform.

        Only the 2D rotation block and 2D translation of the 4x4 transform
        are applied; heights are not involved. Timestamp/interval arguments
        are accepted for interface compatibility but unused here.
        """
        dst_anchors = []
        for i in range(len(T_src2dst_list)):
            dst_anchor = anchor.clone()
            bs, num_anchor, _ = anchor.shape
            # (bs, num_anchor * num_sample, 2) for batched point transform.
            dst_anchor = dst_anchor.reshape(bs, num_anchor, self.num_sample, -1).flatten(1, 2)
            T_src2dst = torch.unsqueeze(
                T_src2dst_list[i].to(dtype=anchor.dtype), dim=1
            )

            # p' = R_2x2 @ p + t_2 (top-left block and last column).
            dst_anchor = (
                torch.matmul(
                    T_src2dst[..., :2, :2], dst_anchor[..., None]
                ).squeeze(dim=-1)
                + T_src2dst[..., :2, 3]
            )

            dst_anchor = dst_anchor.reshape(bs, num_anchor, self.num_sample, -1).flatten(2, 3)
            dst_anchors.append(dst_anchor)
        return dst_anchors
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/models/map/match_cost.py b/projects/mmdet3d_plugin/models/map/match_cost.py
new file mode 100644
index 0000000..3c19a71
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/map/match_cost.py
@@ -0,0 +1,104 @@
+import torch
+from mmdet.core.bbox.match_costs.builder import MATCH_COST
+from mmdet.core.bbox.match_costs import build_match_cost
+from torch.nn.functional import smooth_l1_loss
+
+
@MATCH_COST.register_module()
class LinesL1Cost(object):
    """Pairwise L1 (or smooth-L1) matching cost between polylines.

    Args:
        weight (int | float): scale applied to the cost matrix.
        beta (float): smooth-L1 beta; <= 0 uses plain L1 via ``cdist``.
        permute (bool): whether GT lines carry a permutation axis; the
            minimum cost over permutations is used and the argmin returned.
    """

    def __init__(self, weight=1.0, beta=0.0, permute=False):
        self.weight = weight
        self.permute = permute
        self.beta = beta

    def __call__(self, lines_pred, gt_lines, **kwargs):
        """Compute the cost matrix.

        Args:
            lines_pred (Tensor): [num_pred, 2*num_points], normalized.
            gt_lines (Tensor): [num_gt, 2*num_points] or, with permute,
                [num_gt, num_permute, 2*num_points].

        Returns:
            Tensor [num_pred, num_gt] (and the per-pair best permutation
            index when ``permute`` is set).
        """
        expected_rank = 3 if self.permute else 2
        assert len(gt_lines.shape) == expected_rank

        num_pred, num_gt = len(lines_pred), len(gt_lines)
        if self.permute:
            # Fold permutations into the GT axis; unfolded again below.
            gt_lines = gt_lines.flatten(0, 1)

        num_pts = lines_pred.shape[-1] // 2

        if self.beta > 0:
            pred_mat = lines_pred.unsqueeze(1).repeat(1, len(gt_lines), 1)
            gt_mat = gt_lines.unsqueeze(0).repeat(num_pred, 1, 1)
            dist_mat = smooth_l1_loss(
                pred_mat, gt_mat, reduction='none', beta=self.beta
            ).sum(-1)
        else:
            dist_mat = torch.cdist(lines_pred, gt_lines, p=1)

        dist_mat = dist_mat / num_pts

        if self.permute:
            # (num_pred, num_gt, num_permute) -> min over permutations.
            dist_mat = dist_mat.view(num_pred, num_gt, -1)
            dist_mat, gt_permute_index = dist_mat.min(dim=2)
            return dist_mat * self.weight, gt_permute_index

        return dist_mat * self.weight
+
+
@MATCH_COST.register_module()
class MapQueriesCost(object):
    """Combined matching cost for map queries.

    Sums a classification cost, a line regression cost, and an optional
    IoU cost. When the regression cost is permutation-aware it also
    returns the chosen GT permutation indices.
    """

    def __init__(self, cls_cost, reg_cost, iou_cost=None):
        self.cls_cost = build_match_cost(cls_cost)
        self.reg_cost = build_match_cost(reg_cost)

        self.iou_cost = None
        if iou_cost is not None:
            self.iou_cost = build_match_cost(iou_cost)

    def __call__(self, preds: dict, gts: dict, ignore_cls_cost: bool):
        """Return the (num_pred, num_gt) cost matrix.

        Args:
            preds: dict with 'scores', 'lines' and optionally 'masks'.
            gts: dict with 'labels', 'lines' and optionally 'masks'.
            ignore_cls_cost: when True, only the regression (and IoU) cost
                contributes.
        """
        cls_cost = self.cls_cost(preds['scores'], gts['labels'])

        regkwargs = {}
        if 'masks' in preds and 'masks' in gts:
            # BUGFIX: the original asserted
            # isinstance(self.reg_cost, DynamicLinesCost), but
            # DynamicLinesCost is never imported or defined in this module,
            # so reaching this branch raised NameError. The reg cost just
            # needs to accept the mask kwargs (duck typing).
            regkwargs = {
                'masks_pred': preds['masks'],
                'masks_gt': gts['masks'],
            }

        reg_cost = self.reg_cost(preds['lines'], gts['lines'], **regkwargs)
        gt_permute_idx = None
        if self.reg_cost.permute:
            reg_cost, gt_permute_idx = reg_cost

        # Weighted sum of the component costs.
        if ignore_cls_cost:
            cost = reg_cost
        else:
            cost = cls_cost + reg_cost

        if self.iou_cost is not None:
            cost = cost + self.iou_cost(preds['lines'], gts['lines'])

        if self.reg_cost.permute:
            return cost, gt_permute_idx
        return cost
diff --git a/projects/mmdet3d_plugin/models/map/target.py b/projects/mmdet3d_plugin/models/map/target.py
new file mode 100644
index 0000000..6695fce
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/map/target.py
@@ -0,0 +1,167 @@
+import torch
+import numpy as np
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+
+from mmdet.core.bbox.builder import (BBOX_SAMPLERS, BBOX_ASSIGNERS)
+from mmdet.core.bbox.match_costs import build_match_cost
+from mmdet.core import (build_assigner, build_sampler)
+from mmdet.core.bbox.assigners import (AssignResult, BaseAssigner)
+
+from ..base_target import BaseTargetWithDenoising
+
+
@BBOX_SAMPLERS.register_module()
class SparsePoint3DTarget(BaseTargetWithDenoising):
    """Build per-query classification/regression targets for map lines.

    Runs the Hungarian assigner per batch item on ROI-normalized lines and
    scatters the matched GT labels/lines (with the best permutation) into
    dense target tensors.
    """

    def __init__(
        self,
        assigner=None,
        num_dn_groups=0,
        dn_noise_scale=0.5,
        max_dn_gt=32,
        add_neg_dn=True,
        num_temp_dn_groups=0,
        num_cls=3,
        num_sample=20,
        roi_size=(30, 60),
    ):
        super(SparsePoint3DTarget, self).__init__(
            num_dn_groups, num_temp_dn_groups
        )
        self.assigner = build_assigner(assigner)
        self.dn_noise_scale = dn_noise_scale
        self.max_dn_gt = max_dn_gt
        self.add_neg_dn = add_neg_dn
        self.num_cls = num_cls
        self.num_sample = num_sample
        self.roi_size = roi_size

    def sample(
        self,
        cls_preds,
        pts_preds,
        cls_targets,
        pts_targets,
    ):
        """Assign GTs to predictions and build dense targets.

        Returns (cls_target, box_target, reg_weights); unmatched queries
        get the background label ``num_cls`` and zero regression weight.
        """
        # Flatten (num_gt, num_permute, num_pts, 2) targets when present.
        pts_targets = [
            t.flatten(2, 3) if t.dim() == 4 else t for t in pts_targets
        ]
        assignments = []
        for cls_pred, pts_pred, cls_target, pts_target in zip(
            cls_preds, pts_preds, cls_targets, pts_targets
        ):
            preds = dict(
                lines=self.normalize_line(pts_pred), scores=cls_pred
            )
            gts = dict(
                lines=self.normalize_line(pts_target), labels=cls_target
            )
            assignments.append(self.assigner.assign(preds, gts))

        bs, num_pred, num_cls = cls_preds.shape
        output_cls_target = cls_targets[0].new_full(
            [bs, num_pred], num_cls, dtype=torch.long
        )
        output_box_target = pts_preds.new_zeros(pts_preds.shape)
        output_reg_weights = pts_preds.new_zeros(pts_preds.shape)
        for i, (pred_idx, target_idx, gt_permute_index) in enumerate(
            assignments
        ):
            if len(cls_targets[i]) == 0:
                continue
            # Best GT permutation for each matched (pred, gt) pair.
            permute_idx = gt_permute_index[pred_idx, target_idx]
            output_cls_target[i, pred_idx] = cls_targets[i][target_idx]
            output_box_target[i, pred_idx] = pts_targets[i][
                target_idx, permute_idx
            ]
            output_reg_weights[i, pred_idx] = 1

        return output_cls_target, output_box_target, output_reg_weights

    def normalize_line(self, line):
        """Map ROI-frame coordinates into (0, 1) per axis."""
        if line.shape[0] == 0:
            return line

        pts = line.view(line.shape[:-1] + (self.num_sample, -1))

        # Shift the ROI's bottom-left corner to the origin.
        half = pts.new_tensor([self.roi_size[0] / 2, self.roi_size[1] / 2])
        pts = pts + half

        # Scale by the ROI extent; eps keeps boundary points strictly < 1.
        eps = 1e-5
        extent = pts.new_tensor([self.roi_size[0], self.roi_size[1]]) + eps
        pts = pts / extent

        return pts.flatten(-2, -1)
+
+
@BBOX_ASSIGNERS.register_module()
class HungarianLinesAssigner(BaseAssigner):
    """
    Computes one-to-one matching between line predictions and ground truth.

    The assignment minimizes a weighted cost (classification cost plus
    line regression L1 cost, optionally an IoU cost) using the Hungarian
    algorithm. Unlike mmdet's HungarianAssigner, ``assign`` here returns
    the raw matched index arrays rather than an ``AssignResult``.

    Args:
        cost (dict): Config for the combined match cost (e.g.
            ``MapQueriesCost``), built via ``build_match_cost``.
            NOTE(review): the default ``cost=dict`` is the builtin type,
            not a config dict — ``build_match_cost`` would fail on it, so
            an explicit cost config is effectively required; consider
            defaulting to None.
    """

    def __init__(self, cost=dict, **kwargs):
        self.cost = build_match_cost(cost)

    def assign(self,
               preds: dict,
               gts: dict,
               ignore_cls_cost=False,
               gt_bboxes_ignore=None,
               eps=1e-7):
        """
        Compute one-to-one matching based on the weighted costs.

        Args:
            preds (dict): Prediction dict with normalized
                'lines' [num_query, 2*num_points] and 'scores'
                [num_query, num_class].
            gts (dict): Ground-truth dict with 'lines'
                [num_gt, ...] and 'labels' (num_gt,).
            ignore_cls_cost (bool): When True, match on the regression
                cost only.
            gt_bboxes_ignore: Unsupported; must be None.
            eps (float): Unused; kept for interface compatibility.

        Returns:
            tuple: ``(matched_row_inds, matched_col_inds, gt_permute_idx)``
            as numpy arrays / tensor, or ``(None, None, None)`` when there
            are no predictions or no ground truths. ``gt_permute_idx`` is
            only non-None for permutation-aware regression costs.
        """
        assert gt_bboxes_ignore is None, \
            'Only case when gt_bboxes_ignore is None is supported.'

        num_gts, num_lines = gts['lines'].size(0), preds['lines'].size(0)
        if num_gts == 0 or num_lines == 0:
            return None, None, None

        # compute the weighted costs
        gt_permute_idx = None  # (num_preds, num_gts)
        if self.cost.reg_cost.permute:
            cost, gt_permute_idx = self.cost(preds, gts, ignore_cls_cost)
        else:
            cost = self.cost(preds, gts, ignore_cls_cost)

        # do Hungarian matching on CPU using linear_sum_assignment
        cost = cost.detach().cpu().numpy()
        matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
        return matched_row_inds, matched_col_inds, gt_permute_idx
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/models/motion/__init__.py b/projects/mmdet3d_plugin/models/motion/__init__.py
new file mode 100644
index 0000000..11b2f42
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/motion/__init__.py
@@ -0,0 +1,6 @@
+from .motion_planning_head import MotionPlanningHead
+from .kinematic_motion_planning_head import KinematicMotionPlanningHead
+from .motion_blocks import MotionPlanningRefinementModule
+from .instance_queue import InstanceQueue
+from .target import MotionTarget, PlanningTarget
+from .decoder import SparseBox3DMotionDecoder, HierarchicalPlanningDecoder
diff --git a/projects/mmdet3d_plugin/models/motion/decoder.py b/projects/mmdet3d_plugin/models/motion/decoder.py
new file mode 100644
index 0000000..fa673cf
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/motion/decoder.py
@@ -0,0 +1,329 @@
+from typing import Optional
+
+import numpy as np
+import torch
+
+from mmdet.core.bbox.builder import BBOX_CODERS
+
+from projects.mmdet3d_plugin.core.box3d import *
+from projects.mmdet3d_plugin.models.detection3d.decoder import *
+from projects.mmdet3d_plugin.datasets.utils import box3d_to_corners
+
+
+@BBOX_CODERS.register_module()
+class SparseBox3DMotionDecoder(SparseBox3DDecoder):
+ def __init__(self):
+ super(SparseBox3DMotionDecoder, self).__init__()
+
+ def decode(
+ self,
+ cls_scores,
+ box_preds,
+ instance_id=None,
+ quality=None,
+ motion_output=None,
+ output_idx=-1,
+ ):
+ squeeze_cls = instance_id is not None
+
+ cls_scores = cls_scores[output_idx].sigmoid()
+
+ if squeeze_cls:
+ cls_scores, cls_ids = cls_scores.max(dim=-1)
+ cls_scores = cls_scores.unsqueeze(dim=-1)
+
+ box_preds = box_preds[output_idx]
+ bs, num_pred, num_cls = cls_scores.shape
+ cls_scores, indices = cls_scores.flatten(start_dim=1).topk(
+ self.num_output, dim=1, sorted=self.sorted
+ )
+ if not squeeze_cls:
+ cls_ids = indices % num_cls
+ if self.score_threshold is not None:
+ mask = cls_scores >= self.score_threshold
+
+ if quality[output_idx] is None:
+ quality = None
+ if quality is not None:
+ centerness = quality[output_idx][..., CNS]
+ centerness = torch.gather(centerness, 1, indices // num_cls)
+ cls_scores_origin = cls_scores.clone()
+ cls_scores *= centerness.sigmoid()
+ cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True)
+ if not squeeze_cls:
+ cls_ids = torch.gather(cls_ids, 1, idx)
+ if self.score_threshold is not None:
+ mask = torch.gather(mask, 1, idx)
+ indices = torch.gather(indices, 1, idx)
+
+ output = []
+ anchor_queue = motion_output["anchor_queue"]
+ anchor_queue = torch.stack(anchor_queue, dim=2)
+ period = motion_output["period"]
+
+ for i in range(bs):
+ category_ids = cls_ids[i]
+ if squeeze_cls:
+ category_ids = category_ids[indices[i]]
+ scores = cls_scores[i]
+ box = box_preds[i, indices[i] // num_cls]
+ if self.score_threshold is not None:
+ category_ids = category_ids[mask[i]]
+ scores = scores[mask[i]]
+ box = box[mask[i]]
+ if quality is not None:
+ scores_origin = cls_scores_origin[i]
+ if self.score_threshold is not None:
+ scores_origin = scores_origin[mask[i]]
+
+ box = decode_box(box)
+ trajs = motion_output["prediction"][-1]
+ traj_cls = motion_output["classification"][-1].sigmoid()
+ traj = trajs[i, indices[i] // num_cls]
+ traj_cls = traj_cls[i, indices[i] // num_cls]
+ if self.score_threshold is not None:
+ traj = traj[mask[i]]
+ traj_cls = traj_cls[mask[i]]
+ traj = traj.cumsum(dim=-2) + box[:, None, None, :2]
+ output.append(
+ {
+ "trajs_3d": traj.cpu(),
+ "trajs_score": traj_cls.cpu()
+ }
+ )
+
+ temp_anchor = anchor_queue[i, indices[i] // num_cls]
+ temp_period = period[i, indices[i] // num_cls]
+ if self.score_threshold is not None:
+ temp_anchor = temp_anchor[mask[i]]
+ temp_period = temp_period[mask[i]]
+ num_pred, queue_len = temp_anchor.shape[:2]
+ temp_anchor = temp_anchor.flatten(0, 1)
+ temp_anchor = decode_box(temp_anchor)
+ temp_anchor = temp_anchor.reshape([num_pred, queue_len, box.shape[-1]])
+ output[-1]['anchor_queue'] = temp_anchor.cpu()
+ output[-1]['period'] = temp_period.cpu()
+
+ return output
+
+
+@BBOX_CODERS.register_module()
+class HierarchicalPlanningDecoder(object):
+ def __init__(
+ self,
+ ego_fut_ts,
+ ego_fut_mode,
+ use_rescore=False,
+ ):
+ super(HierarchicalPlanningDecoder, self).__init__()
+ self.ego_fut_ts = ego_fut_ts
+ self.ego_fut_mode = ego_fut_mode
+ self.use_rescore = use_rescore
+
+ def decode(
+ self,
+ det_output,
+ motion_output,
+ planning_output,
+ data,
+ ):
+ classification = planning_output['classification'][-1]
+ prediction = planning_output['prediction'][-1]
+ bs = classification.shape[0]
+ classification = classification.reshape(bs, 3, self.ego_fut_mode)
+ prediction = prediction.reshape(bs, 3, self.ego_fut_mode, self.ego_fut_ts, 2).cumsum(dim=-2)
+ classification, final_planning = self.select(det_output, motion_output, classification, prediction, data)
+ anchor_queue = planning_output["anchor_queue"]
+ anchor_queue = torch.stack(anchor_queue, dim=2)
+ period = planning_output["period"]
+ output = []
+ for i, (cls, pred) in enumerate(zip(classification, prediction)):
+ output.append(
+ {
+ "planning_score": cls.sigmoid().cpu(),
+ "planning": pred.cpu(),
+ "final_planning": final_planning[i].cpu(),
+ "ego_period": period[i].cpu(),
+ "ego_anchor_queue": decode_box(anchor_queue[i]).cpu(),
+ }
+ )
+
+ return output
+
+ def select(
+ self,
+ det_output,
+ motion_output,
+ plan_cls,
+ plan_reg,
+ data,
+ ):
+ det_classification = det_output["classification"][-1].sigmoid()
+ det_anchors = det_output["prediction"][-1]
+ det_confidence = det_classification.max(dim=-1).values
+ motion_cls = motion_output["classification"][-1].sigmoid()
+ motion_reg = motion_output["prediction"][-1]
+
+ # cmd select
+ bs = motion_cls.shape[0]
+ bs_indices = torch.arange(bs, device=motion_cls.device)
+ cmd = data['gt_ego_fut_cmd'].argmax(dim=-1)
+ plan_cls_full = plan_cls.detach().clone()
+ plan_cls = plan_cls[bs_indices, cmd]
+ plan_reg = plan_reg[bs_indices, cmd]
+
+ # rescore
+ if self.use_rescore:
+ plan_cls = self.rescore(
+ plan_cls,
+ plan_reg,
+ motion_cls,
+ motion_reg,
+ det_anchors,
+ det_confidence,
+ )
+ plan_cls_full[bs_indices, cmd] = plan_cls
+ mode_idx = plan_cls.argmax(dim=-1)
+ final_planning = plan_reg[bs_indices, mode_idx]
+ return plan_cls_full, final_planning
+
+ def rescore(
+ self,
+ plan_cls,
+ plan_reg,
+ motion_cls,
+ motion_reg,
+ det_anchors,
+ det_confidence,
+ score_thresh=0.5,
+ static_dis_thresh=0.5,
+ dim_scale=1.1,
+ num_motion_mode=1,
+ offset=0.5,
+ ):
+
+ def cat_with_zero(traj):
+ zeros = traj.new_zeros(traj.shape[:-2] + (1, 2))
+ traj_cat = torch.cat([zeros, traj], dim=-2)
+ return traj_cat
+
+ def get_yaw(traj, start_yaw=np.pi/2):
+ yaw = traj.new_zeros(traj.shape[:-1])
+ yaw[..., 1:-1] = torch.atan2(
+ traj[..., 2:, 1] - traj[..., :-2, 1],
+ traj[..., 2:, 0] - traj[..., :-2, 0],
+ )
+ yaw[..., -1] = torch.atan2(
+ traj[..., -1, 1] - traj[..., -2, 1],
+ traj[..., -1, 0] - traj[..., -2, 0],
+ )
+ yaw[..., 0] = start_yaw
+ # for static object, estimated future yaw would be unstable
+ start = traj[..., 0, :]
+ end = traj[..., -1, :]
+ dist = torch.linalg.norm(end - start, dim=-1)
+ mask = dist < static_dis_thresh
+ start_yaw = yaw[..., 0].unsqueeze(-1)
+ yaw = torch.where(
+ mask.unsqueeze(-1),
+ start_yaw,
+ yaw,
+ )
+ return yaw.unsqueeze(-1)
+
+ ## ego
+ bs = plan_reg.shape[0]
+ plan_reg_cat = cat_with_zero(plan_reg)
+ ego_box = det_anchors.new_zeros(bs, self.ego_fut_mode, self.ego_fut_ts + 1, 7)
+ ego_box[..., [X, Y]] = plan_reg_cat
+ ego_box[..., [W, L, H]] = ego_box.new_tensor([4.08, 1.73, 1.56]) * dim_scale
+ ego_box[..., [YAW]] = get_yaw(plan_reg_cat)
+
+ ## motion
+ motion_reg = motion_reg[..., :self.ego_fut_ts, :].cumsum(-2)
+ motion_reg = cat_with_zero(motion_reg) + det_anchors[:, :, None, None, :2]
+ _, motion_mode_idx = torch.topk(motion_cls, num_motion_mode, dim=-1)
+ motion_mode_idx = motion_mode_idx[..., None, None].repeat(1, 1, 1, self.ego_fut_ts + 1, 2)
+ motion_reg = torch.gather(motion_reg, 2, motion_mode_idx)
+
+ motion_box = motion_reg.new_zeros(motion_reg.shape[:-1] + (7,))
+ motion_box[..., [X, Y]] = motion_reg
+ motion_box[..., [W, L, H]] = det_anchors[..., None, None, [W, L, H]].exp()
+ box_yaw = torch.atan2(
+ det_anchors[..., SIN_YAW],
+ det_anchors[..., COS_YAW],
+ )
+ motion_box[..., [YAW]] = get_yaw(motion_reg, box_yaw.unsqueeze(-1))
+
+ filter_mask = det_confidence < score_thresh
+ motion_box[filter_mask] = 1e6
+
+ ego_box = ego_box[..., 1:, :]
+ motion_box = motion_box[..., 1:, :]
+
+ bs, num_ego_mode, ts, _ = ego_box.shape
+ bs, num_anchor, num_motion_mode, ts, _ = motion_box.shape
+ ego_box = ego_box[:, None, None].repeat(1, num_anchor, num_motion_mode, 1, 1, 1).flatten(0, -2)
+ motion_box = motion_box.unsqueeze(3).repeat(1, 1, 1, num_ego_mode, 1, 1).flatten(0, -2)
+
+ ego_box[0] += offset * torch.cos(ego_box[6])
+ ego_box[1] += offset * torch.sin(ego_box[6])
+ col = check_collision(ego_box, motion_box)
+ col = col.reshape(bs, num_anchor, num_motion_mode, num_ego_mode, ts).permute(0, 3, 1, 2, 4)
+ col = col.flatten(2, -1).any(dim=-1)
+ all_col = col.all(dim=-1)
+ col[all_col] = False # for case that all modes collide, no need to rescore
+ score_offset = col.float() * -999
+ plan_cls = plan_cls + score_offset
+ return plan_cls
+
+
+def check_collision(boxes1, boxes2):
+    '''
+    A rough check for collision detection:
+    check if any corner point of boxes1 is inside boxes2 and vice versa.
+
+    boxes1: tensor with shape [N, 7], [x, y, z, w, l, h, yaw]
+    boxes2: tensor with shape [N, 7]
+    '''
+    # Symmetric test; clones are passed because corners_in_box mutates its
+    # arguments in place (translation/rotation into the box frame).
+    col_1 = corners_in_box(boxes1.clone(), boxes2.clone())
+    col_2 = corners_in_box(boxes2.clone(), boxes1.clone())
+    # NOTE(review): corners_in_box returns the Python bool False when either
+    # input is empty; torch.logical_or(False, False) would raise. Confirm
+    # callers never pass two empty tensors.
+    collision = torch.logical_or(col_1, col_2)
+
+    return collision
+
+def corners_in_box(boxes1, boxes2):
+    """Test, per box pair, whether any bottom corner of boxes2 lies inside boxes1.
+
+    Both inputs are [N, 7] tensors laid out as [x, y, z, w, l, h, yaw] and
+    are mutated in place (callers pass clones). boxes1 and boxes2 are paired
+    row-by-row, so N must match.
+
+    Returns a boolean tensor of shape (N,), or the Python bool False when
+    either input is empty.
+    """
+    if boxes1.shape[0] == 0 or boxes2.shape[0] == 0:
+        return False
+
+    boxes1_yaw = boxes1[:, 6].clone()
+    boxes1_loc = boxes1[:, :3].clone()
+    # Per-box rotation by -yaw, stacked as (2, 2, N) for the einsum below.
+    cos_yaw = torch.cos(-boxes1_yaw)
+    sin_yaw = torch.sin(-boxes1_yaw)
+    rot_mat_T = torch.stack(
+        [
+            torch.stack([cos_yaw, sin_yaw]),
+            torch.stack([-sin_yaw, cos_yaw]),
+        ]
+    )
+    # translate and rotate boxes into the boxes1-centric frame
+    boxes1[:, :3] = boxes1[:, :3] - boxes1_loc
+    boxes1[:, :2] = torch.einsum('ij,jki->ik', boxes1[:, :2], rot_mat_T)
+    boxes1[:, 6] = boxes1[:, 6] - boxes1_yaw
+
+    boxes2[:, :3] = boxes2[:, :3] - boxes1_loc
+    boxes2[:, :2] = torch.einsum('ij,jki->ik', boxes2[:, :2], rot_mat_T)
+    boxes2[:, 6] = boxes2[:, 6] - boxes1_yaw
+
+    # box3d_to_corners returns a numpy array; indices [0, 3, 7, 4] select
+    # four corners of one face (presumably the bottom face — confirm
+    # against the corner ordering in datasets.utils.box3d_to_corners).
+    corners_box2 = box3d_to_corners(boxes2)[:, [0, 3, 7, 4], :2]
+    corners_box2 = torch.from_numpy(corners_box2).to(boxes2.device)
+    # NOTE(review): per the [x, y, z, w, l, h, yaw] layout, index 3 is w
+    # and index 4 is l — the H/W names here look swapped relative to the
+    # layout comment, though the inside test is self-consistent.
+    H = boxes1[:, [3]]
+    W = boxes1[:, [4]]
+
+    # A corner is inside iff both axis-aligned half-extent tests pass in
+    # the rotated frame; any corner inside flags the pair.
+    collision = torch.logical_and(
+        torch.logical_and(corners_box2[..., 0] <= H / 2, corners_box2[..., 0] >= -H / 2),
+        torch.logical_and(corners_box2[..., 1] <= W / 2, corners_box2[..., 1] >= -W / 2),
+    )
+    collision = collision.any(dim=-1)
+
+    return collision
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/models/motion/instance_queue.py b/projects/mmdet3d_plugin/models/motion/instance_queue.py
new file mode 100644
index 0000000..1905d77
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/motion/instance_queue.py
@@ -0,0 +1,213 @@
+import copy
+import torch
+from torch import nn
+import torch.nn.functional as F
+import numpy as np
+
+from mmcv.utils import build_from_cfg
+from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
+
+from projects.mmdet3d_plugin.ops import feature_maps_format
+from projects.mmdet3d_plugin.core.box3d import *
+
+
+@PLUGIN_LAYERS.register_module()
+class InstanceQueue(nn.Module):
+ def __init__(
+ self,
+ embed_dims,
+ queue_length=0,
+ tracking_threshold=0,
+ feature_map_scale=None,
+ ):
+ super(InstanceQueue, self).__init__()
+ self.embed_dims = embed_dims
+ self.queue_length = queue_length
+ self.tracking_threshold = tracking_threshold
+
+ kernel_size = tuple([int(x / 2) for x in feature_map_scale])
+ self.ego_feature_encoder = nn.Sequential(
+ nn.Conv2d(embed_dims, embed_dims, 3, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(embed_dims),
+ nn.Conv2d(embed_dims, embed_dims, 3, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(embed_dims),
+ nn.ReLU(),
+ nn.AvgPool2d(kernel_size),
+ )
+ self.ego_anchor = nn.Parameter(
+ torch.tensor([[0, 0.5, -1.84 + 1.56/2, np.log(4.08), np.log(1.73), np.log(1.56), 1, 0, 0, 0, 0],], dtype=torch.float32),
+ requires_grad=False,
+ )
+
+ self.reset()
+
+ def reset(self):
+ self.metas = None
+ self.prev_instance_id = None
+ self.prev_confidence = None
+ self.period = None
+ self.instance_feature_queue = []
+ self.anchor_queue = []
+ self.prev_ego_status = None
+ self.ego_period = None
+ self.ego_feature_queue = []
+ self.ego_anchor_queue = []
+
+ def get(
+ self,
+ det_output,
+ feature_maps,
+ metas,
+ batch_size,
+ mask,
+ anchor_handler,
+ ):
+ if (
+ self.period is not None
+ and batch_size == self.period.shape[0]
+ ):
+ if anchor_handler is not None:
+ T_temp2cur = feature_maps[0].new_tensor(
+ np.stack(
+ [
+ x["T_global_inv"]
+ @ self.metas["img_metas"][i]["T_global"]
+ for i, x in enumerate(metas["img_metas"])
+ ]
+ )
+ )
+ for i in range(len(self.anchor_queue)):
+ temp_anchor = self.anchor_queue[i]
+ temp_anchor = anchor_handler.anchor_projection(
+ temp_anchor,
+ [T_temp2cur],
+ )[0]
+ self.anchor_queue[i] = temp_anchor
+ for i in range(len(self.ego_anchor_queue)):
+ temp_anchor = self.ego_anchor_queue[i]
+ temp_anchor = anchor_handler.anchor_projection(
+ temp_anchor,
+ [T_temp2cur],
+ )[0]
+ self.ego_anchor_queue[i] = temp_anchor
+ else:
+ self.reset()
+
+ self.prepare_motion(det_output, mask)
+ ego_feature, ego_anchor = self.prepare_planning(feature_maps, mask, batch_size)
+
+ # temporal
+ temp_instance_feature = torch.stack(self.instance_feature_queue, dim=2)
+ temp_anchor = torch.stack(self.anchor_queue, dim=2)
+ temp_ego_feature = torch.stack(self.ego_feature_queue, dim=2)
+ temp_ego_anchor = torch.stack(self.ego_anchor_queue, dim=2)
+
+ period = torch.cat([self.period, self.ego_period], dim=1)
+ temp_instance_feature = torch.cat([temp_instance_feature, temp_ego_feature], dim=1)
+ temp_anchor = torch.cat([temp_anchor, temp_ego_anchor], dim=1)
+ num_agent = temp_anchor.shape[1]
+
+ temp_mask = torch.arange(len(self.anchor_queue), 0, -1, device=temp_anchor.device)
+ temp_mask = temp_mask[None, None].repeat((batch_size, num_agent, 1))
+ temp_mask = torch.gt(temp_mask, period[..., None])
+
+ return ego_feature, ego_anchor, temp_instance_feature, temp_anchor, temp_mask
+
+ def prepare_motion(
+ self,
+ det_output,
+ mask,
+ ):
+ instance_feature = det_output["instance_feature"]
+ det_anchors = det_output["prediction"][-1]
+
+ if self.period == None:
+ self.period = instance_feature.new_zeros(instance_feature.shape[:2]).long()
+ else:
+ instance_id = det_output['instance_id']
+ prev_instance_id = self.prev_instance_id
+ match = instance_id[..., None] == prev_instance_id[:, None]
+ if self.tracking_threshold > 0:
+ temp_mask = self.prev_confidence > self.tracking_threshold
+ match = match * temp_mask.unsqueeze(1)
+
+ for i in range(len(self.instance_feature_queue)):
+ temp_feature = self.instance_feature_queue[i]
+ temp_feature = (
+ match[..., None] * temp_feature[:, None]
+ ).sum(dim=2)
+ self.instance_feature_queue[i] = temp_feature
+
+ temp_anchor = self.anchor_queue[i]
+ temp_anchor = (
+ match[..., None] * temp_anchor[:, None]
+ ).sum(dim=2)
+ self.anchor_queue[i] = temp_anchor
+
+ self.period = (
+ match * self.period[:, None]
+ ).sum(dim=2)
+
+ self.instance_feature_queue.append(instance_feature.detach())
+ self.anchor_queue.append(det_anchors.detach())
+ self.period += 1
+
+ if len(self.instance_feature_queue) > self.queue_length:
+ self.instance_feature_queue.pop(0)
+ self.anchor_queue.pop(0)
+ self.period = torch.clip(self.period, 0, self.queue_length)
+
+ def prepare_planning(
+ self,
+ feature_maps,
+ mask,
+ batch_size,
+ ):
+ ## ego instance init
+ feature_maps_inv = feature_maps_format(feature_maps, inverse=True)
+ feature_map = feature_maps_inv[0][-1][:, 0]
+ ego_feature = self.ego_feature_encoder(feature_map)
+ ego_feature = ego_feature.unsqueeze(1).squeeze(-1).squeeze(-1)
+
+ ego_anchor = torch.tile(
+ self.ego_anchor[None], (batch_size, 1, 1)
+ )
+ if self.prev_ego_status is not None:
+ prev_ego_status = torch.where(
+ mask[:, None, None],
+ self.prev_ego_status,
+ self.prev_ego_status.new_tensor(0),
+ )
+ ego_anchor[..., VY] = prev_ego_status[..., 6]
+
+ if self.ego_period == None:
+ self.ego_period = ego_feature.new_zeros((batch_size, 1)).long()
+ else:
+ self.ego_period = torch.where(
+ mask[:, None],
+ self.ego_period,
+ self.ego_period.new_tensor(0),
+ )
+
+ self.ego_feature_queue.append(ego_feature.detach())
+ self.ego_anchor_queue.append(ego_anchor.detach())
+ self.ego_period += 1
+
+ if len(self.ego_feature_queue) > self.queue_length:
+ self.ego_feature_queue.pop(0)
+ self.ego_anchor_queue.pop(0)
+ self.ego_period = torch.clip(self.ego_period, 0, self.queue_length)
+
+ return ego_feature, ego_anchor
+
+ def cache_motion(self, instance_feature, det_output, metas):
+ det_classification = det_output["classification"][-1].sigmoid()
+ det_confidence = det_classification.max(dim=-1).values
+ instance_id = det_output['instance_id']
+ self.metas = metas
+ self.prev_confidence = det_confidence.detach()
+ self.prev_instance_id = instance_id
+
+ def cache_planning(self, ego_feature, ego_status):
+ self.prev_ego_status = ego_status.detach()
+ self.ego_feature_queue[-1] = ego_feature.detach()
diff --git a/projects/mmdet3d_plugin/models/motion/kinematic_motion_planning_head.py b/projects/mmdet3d_plugin/models/motion/kinematic_motion_planning_head.py
new file mode 100644
index 0000000..88f0bf2
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/motion/kinematic_motion_planning_head.py
@@ -0,0 +1,409 @@
+import torch
+
+from mmcv.runner import BaseModule, force_fp32
+from mmcv.utils import build_from_cfg
+from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
+from mmdet.core import reduce_mean
+from mmdet.core.bbox.builder import BBOX_SAMPLERS, BBOX_CODERS
+from mmdet.models import HEADS, build_loss
+
+from projects.mmdet3d_plugin.core.box3d import VX, VY
+
+
+@HEADS.register_module()
+class KinematicMotionPlanningHead(BaseModule):
+    """Heuristic kinematic baseline for motion and ego-planning prediction.
+
+    Implements the CTRA (Constant Turn Rate and Acceleration) kinematic model.
+    Special cases are selected via config flags:
+
+        use_acceleration use_turn_rate  Model
+        ──────────────────────────────────────
+        False            False          CV   (Constant Velocity)
+        True             False          CA   (Constant Acceleration)
+        False            True           CVTR (Constant Velocity + Turn Rate)
+        True             True           CTRA (Constant Turn Rate + Acceleration)
+
+    No prediction parameters are trained. Trajectories are derived entirely
+    from the GT box state stored in ``det_output["prediction"][-1]`` and, when
+    acceleration/turn-rate estimation is enabled, the previous-frame anchor
+    history in ``instance_queue.anchor_queue``.
+
+    CTRA integration
+    ----------------
+    Given scalar speed ``s₀``, heading ``θ₀``, turn rate ``ω`` and longitudinal
+    acceleration ``a`` (all estimated at the current frame), each future
+    timestep delta is computed via midpoint integration::
+
+        s_mid(t) = s₀ + a · (t + ½) · dt   (clipped to ≥ 0)
+        θ_mid(t) = θ₀ + ω · (t + ½) · dt
+        Δx_t = s_mid(t) · cos(θ_mid(t)) · dt
+        Δy_t = s_mid(t) · sin(θ_mid(t)) · dt
+
+    where ``t`` is zero-indexed over ``fut_ts`` steps.
+
+    When ``anchor_queue`` is ``None`` (first frame of a sequence), acceleration
+    and turn-rate estimates are unavailable and the model falls back to CV.
+
+    Interface
+    ---------
+    Signature-compatible with ``MotionPlanningHead``; swap via config
+    ``motion_plan_head.type``.
+
+    Args:
+        fut_ts: Agent future timesteps.
+        fut_mode: Number of trajectory hypothesis modes.
+        ego_fut_ts: Ego future timesteps.
+        ego_fut_mode: Number of ego hypothesis modes.
+        dt: Seconds per timestep (0.5 s for nuScenes).
+        use_acceleration: Estimate longitudinal acceleration from the
+            velocity delta between current and previous frame.
+        use_turn_rate: Estimate yaw rate from the heading delta between
+            current and previous frame.
+        instance_queue: Config for InstanceQueue (anchor history tracking).
+        motion_sampler / motion_loss_{cls,reg}: Motion loss modules.
+        planning_sampler / plan_loss_{cls,reg,status}: Planning loss modules.
+        motion_decoder / planning_decoder: Decoder configs.
+        num_det / num_map: API compatibility (unused).
+        **kwargs: Absorbs unused ``MotionPlanningHead`` params so that
+            the same config dict works with just a type change.
+    """
+
+    def __init__(
+        self,
+        fut_ts=12,
+        fut_mode=6,
+        ego_fut_ts=6,
+        ego_fut_mode=6,
+        dt=0.5,
+        use_acceleration=False,
+        use_turn_rate=False,
+        instance_queue=None,
+        motion_sampler=None,
+        motion_loss_cls=None,
+        motion_loss_reg=None,
+        planning_sampler=None,
+        plan_loss_cls=None,
+        plan_loss_reg=None,
+        plan_loss_status=None,
+        motion_decoder=None,
+        planning_decoder=None,
+        num_det=50,
+        num_map=10,
+        init_cfg=None,
+    ):
+        super().__init__(init_cfg)
+        self.fut_ts = fut_ts
+        self.fut_mode = fut_mode
+        self.ego_fut_ts = ego_fut_ts
+        self.ego_fut_mode = ego_fut_mode
+        self.dt = dt
+        self.use_acceleration = use_acceleration
+        self.use_turn_rate = use_turn_rate
+        self.num_det = num_det
+        self.num_map = num_map
+
+        # Build a component from its config dict, or keep None when absent.
+        def _build(cfg, registry):
+            return build_from_cfg(cfg, registry) if cfg is not None else None
+
+        self.instance_queue = _build(instance_queue, PLUGIN_LAYERS)
+        if self.instance_queue is not None:
+            # ego_feature_encoder is unused in kinematic mode; freeze to avoid
+            # DDP "unused parameter" errors.
+            for p in self.instance_queue.ego_feature_encoder.parameters():
+                p.requires_grad_(False)
+
+        self.motion_sampler = _build(motion_sampler, BBOX_SAMPLERS)
+        self.planning_sampler = _build(planning_sampler, BBOX_SAMPLERS)
+        self.motion_decoder = _build(motion_decoder, BBOX_CODERS)
+        self.planning_decoder = _build(planning_decoder, BBOX_CODERS)
+
+        self.motion_loss_cls = build_loss(motion_loss_cls)
+        self.motion_loss_reg = build_loss(motion_loss_reg)
+        self.plan_loss_cls = build_loss(plan_loss_cls)
+        self.plan_loss_reg = build_loss(plan_loss_reg)
+        self.plan_loss_status = build_loss(plan_loss_status)
+
+    def init_weights(self):
+        # Delegate to any submodule that defines a custom init_weight hook.
+        if self.instance_queue is not None:
+            for m in self.instance_queue.modules():
+                if hasattr(m, "init_weight"):
+                    m.init_weight()
+
+    # ------------------------------------------------------------------ #
+    # Kinematic helpers
+    # ------------------------------------------------------------------ #
+
+    def _ctra_integrate(self, speed, heading, accel, omega, fut_ts, dt, device):
+        """Midpoint-rule CTRA integration.
+
+        Args:
+            speed: (bs, N) or (bs,) — current scalar speed.
+            heading: (bs, N) or (bs,) — current heading in radians.
+            accel: same shape — longitudinal acceleration (m/s²).
+            omega: same shape — yaw rate (rad/s).
+            fut_ts: number of future timesteps.
+            dt: seconds per timestep.
+
+        Returns:
+            pred: (*shape, fut_ts, 2) per-step XY deltas.
+        """
+        shape = speed.shape
+        t = torch.arange(fut_ts, device=device, dtype=speed.dtype)  # (fut_ts,)
+        t_mid = t + 0.5  # midpoint
+
+        # Broadcast: (*shape, 1) × (fut_ts,) → (*shape, fut_ts)
+        s_mid = (speed.unsqueeze(-1) + accel.unsqueeze(-1) * t_mid * dt).clamp(min=0)
+        h_mid = heading.unsqueeze(-1) + omega.unsqueeze(-1) * t_mid * dt
+
+        dx = s_mid * torch.cos(h_mid) * dt  # (*shape, fut_ts)
+        dy = s_mid * torch.sin(h_mid) * dt
+        return torch.stack([dx, dy], dim=-1)  # (*shape, fut_ts, 2)
+
+    def _agent_kinematics(self, gt_anchors):
+        """Return (speed, heading, accel, omega) for each agent anchor.
+
+        Falls back to zero accel/omega if anchor_queue is unavailable (first
+        frame) or if the corresponding flag is disabled.
+
+        Args:
+            gt_anchors: (bs, N, 11) encoded GT anchors.
+
+        Returns:
+            speed, heading, accel, omega — each (bs, N).
+        """
+        vx = gt_anchors[..., VX]
+        vy = gt_anchors[..., VY]
+        speed = torch.sqrt(vx ** 2 + vy ** 2).clamp(min=1e-6)
+        heading = torch.atan2(vy, vx)
+
+        # anchor_queue[-1] is the current frame (just appended by prepare_motion),
+        # so the actual previous frame is at index -2.
+        queue = self.instance_queue.anchor_queue  # list of (bs, N, 11) tensors
+        have_history = len(queue) >= 2
+
+        if self.use_acceleration and have_history:
+            prev = queue[-2]  # (bs, N, 11) — previous frame
+            prev_speed = torch.sqrt(
+                prev[..., VX] ** 2 + prev[..., VY] ** 2
+            ).clamp(min=1e-6)
+            accel = (speed - prev_speed) / self.dt
+        else:
+            accel = torch.zeros_like(speed)
+
+        if self.use_turn_rate and have_history:
+            prev = queue[-2]  # (bs, N, 11) — previous frame
+            prev_heading = torch.atan2(prev[..., VY], prev[..., VX])
+            omega = (heading - prev_heading) / self.dt
+        else:
+            omega = torch.zeros_like(heading)
+
+        return speed, heading, accel, omega
+
+    def _ego_kinematics(self, ego_status, device):
+        """Return (speed, heading, accel, omega) for the ego vehicle.
+
+        Args:
+            ego_status: (bs, 9) — index 6/7 = vx/vy in current frame.
+
+        Returns:
+            speed, heading, accel, omega — each (bs,).
+        """
+        vx = ego_status[:, 6]
+        vy = ego_status[:, 7]
+        speed = torch.sqrt(vx ** 2 + vy ** 2).clamp(min=1e-6)
+        heading = torch.atan2(vy, vx)
+
+        # ego_anchor_queue[-1] stores the previous frame's velocity in VX/VY slots,
+        # but on the very first frame prev_ego_status was None so VX=VY=0.
+        # Require len >= 2 so the first frame always falls back to zero accel/omega.
+        # NOTE(review): the agent path reads queue[-2] while this reads
+        # ego_queue[-1] — intentional per the comment above (the newest ego
+        # anchor already carries the previous frame's velocity); confirm
+        # against InstanceQueue.prepare_planning.
+        ego_queue = self.instance_queue.ego_anchor_queue  # list of (bs, 1, 11) tensors
+        have_history = len(ego_queue) >= 2
+
+        if self.use_acceleration and have_history:
+            prev_ego = ego_queue[-1][:, 0]  # (bs, 11)
+            prev_vx = prev_ego[..., VX]  # (bs,)
+            prev_vy = prev_ego[..., VY]  # (bs,)
+            prev_speed = torch.sqrt(prev_vx ** 2 + prev_vy ** 2).clamp(min=1e-6)
+            accel = (speed - prev_speed) / self.dt
+        else:
+            accel = torch.zeros_like(speed)
+
+        if self.use_turn_rate and have_history:
+            prev_ego = ego_queue[-1][:, 0]  # (bs, 11)
+            prev_heading = torch.atan2(prev_ego[..., VY], prev_ego[..., VX])
+            omega = (heading - prev_heading) / self.dt
+        else:
+            omega = torch.zeros_like(heading)
+
+        return speed, heading, accel, omega
+
+    # ------------------------------------------------------------------ #
+    # Forward
+    # ------------------------------------------------------------------ #
+
+    def forward(
+        self,
+        det_output,
+        map_output,
+        feature_maps,
+        metas,
+        anchor_encoder,
+        mask,
+        anchor_handler,
+    ):
+        """Produce motion and planning outputs from kinematic rollout.
+
+        map_output and anchor_encoder are accepted for interface
+        compatibility and unused here.
+        """
+        bs = len(metas["img_metas"])
+        gt_anchors = det_output["prediction"][-1]  # (bs, num_anchor, 11)
+        num_anchor = gt_anchors.shape[1]
+        device = gt_anchors.device
+
+        # Update instance_queue for anchor history tracking.
+        ego_feature, ego_anchor, _, _, _ = self.instance_queue.get(
+            det_output, feature_maps, metas, bs, mask, anchor_handler
+        )
+
+        # -------- agent motion -------------------------------------------- #
+        speed, heading, accel, omega = self._agent_kinematics(gt_anchors)
+        # pred: (bs, N, fut_ts, 2) → expand modes → (bs, N, fut_mode, fut_ts, 2)
+        pred = self._ctra_integrate(speed, heading, accel, omega,
+                                    self.fut_ts, self.dt, device)
+        motion_pred = (
+            pred[:, :, None, :, :]
+            .expand(bs, num_anchor, self.fut_mode, self.fut_ts, 2)
+            .contiguous()
+        )
+        motion_cls = gt_anchors.new_zeros(bs, num_anchor, self.fut_mode)
+
+        # -------- ego planning -------------------------------------------- #
+        ego_status = metas["ego_status"]  # (bs, 9)
+        e_speed, e_heading, e_accel, e_omega = self._ego_kinematics(
+            ego_status, device
+        )
+        # pred: (bs, ego_fut_ts, 2) → expand → (bs, 3*ego_fut_mode, ego_fut_ts, 2)
+        ego_pred = self._ctra_integrate(e_speed, e_heading, e_accel, e_omega,
+                                        self.ego_fut_ts, self.dt, device)
+        plan_pred = (
+            ego_pred[:, None, :, :]
+            .expand(bs, 3 * self.ego_fut_mode, self.ego_fut_ts, 2)
+            .contiguous()
+        )
+        plan_cls = gt_anchors.new_zeros(bs, 3 * self.ego_fut_mode)
+        plan_status = ego_status.unsqueeze(1)  # (bs, 1, 9)
+
+        # -------- update queue state -------------------------------------- #
+        zero_feats = gt_anchors.new_zeros(
+            bs, num_anchor, det_output["instance_feature"].shape[-1]
+        )
+        self.instance_queue.cache_motion(zero_feats, det_output, metas)
+        self.instance_queue.cache_planning(ego_feature, plan_status)
+
+        motion_output = {
+            "classification": [motion_cls],
+            "prediction": [motion_pred],
+            "period": self.instance_queue.period,
+            "anchor_queue": self.instance_queue.anchor_queue,
+        }
+        planning_output = {
+            "classification": [plan_cls],
+            "prediction": [plan_pred],
+            "status": [plan_status],
+            "period": self.instance_queue.ego_period,
+            "anchor_queue": self.instance_queue.ego_anchor_queue,
+        }
+        return motion_output, planning_output
+
+    # ------------------------------------------------------------------ #
+    # Loss (computed for monitoring; no params are optimised)
+    # ------------------------------------------------------------------ #
+
+    # NOTE(review): apply_to=("model_outs",) does not name any parameter of
+    # these methods, so force_fp32 casts nothing here — presumably copied
+    # from MotionPlanningHead; confirm whether casting was intended.
+    @force_fp32(apply_to=("model_outs",))
+    def loss(self, motion_model_outs, planning_model_outs, data, motion_loss_cache):
+        """Aggregate motion and planning losses into one dict."""
+        loss = {}
+        loss.update(self.loss_motion(motion_model_outs, data, motion_loss_cache))
+        loss.update(self.loss_planning(planning_model_outs, data))
+        return loss
+
+    @force_fp32(apply_to=("model_outs",))
+    def loss_motion(self, model_outs, data, motion_loss_cache):
+        """Per-layer agent motion classification + regression losses."""
+        output = {}
+        for i, (cls, reg) in enumerate(zip(
+            model_outs["classification"], model_outs["prediction"]
+        )):
+            cls_target, cls_weight, reg_pred, reg_target, reg_weight, num_pos = (
+                self.motion_sampler.sample(
+                    reg,
+                    data["gt_agent_fut_trajs"],
+                    data["gt_agent_fut_masks"],
+                    motion_loss_cache,
+                )
+            )
+            num_pos = max(reduce_mean(num_pos), 1.0)
+
+            cls_loss = self.motion_loss_cls(
+                cls.flatten(end_dim=1),
+                cls_target.flatten(end_dim=1),
+                weight=cls_weight.flatten(end_dim=1),
+                avg_factor=num_pos,
+            )
+            # Compare absolute (cumulative) waypoints, not per-step deltas.
+            reg_pred = reg_pred.flatten(end_dim=1).cumsum(dim=-2)
+            reg_target = reg_target.flatten(end_dim=1).cumsum(dim=-2)
+            reg_loss = self.motion_loss_reg(
+                reg_pred, reg_target,
+                weight=reg_weight.flatten(end_dim=1).unsqueeze(-1),
+                avg_factor=num_pos,
+            )
+            output[f"motion_loss_cls_{i}"] = cls_loss
+            output[f"motion_loss_reg_{i}"] = reg_loss
+        return output
+
+    @force_fp32(apply_to=("model_outs",))
+    def loss_planning(self, model_outs, data):
+        """Per-layer ego planning classification/regression/status losses."""
+        output = {}
+        for i, (cls, reg, status) in enumerate(zip(
+            model_outs["classification"],
+            model_outs["prediction"],
+            model_outs["status"],
+        )):
+            cls, cls_target, cls_weight, reg_pred, reg_target, reg_weight = (
+                self.planning_sampler.sample(
+                    cls, reg,
+                    data["gt_ego_fut_trajs"],
+                    data["gt_ego_fut_masks"],
+                    data,
+                )
+            )
+            cls_loss = self.plan_loss_cls(
+                cls.flatten(end_dim=1),
+                cls_target.flatten(end_dim=1),
+                weight=cls_weight.flatten(end_dim=1),
+            )
+            reg_loss = self.plan_loss_reg(
+                reg_pred.flatten(end_dim=1),
+                reg_target.flatten(end_dim=1),
+                weight=reg_weight.flatten(end_dim=1).unsqueeze(-1),
+            )
+            status_loss = self.plan_loss_status(
+                status.squeeze(1), data["ego_status"]
+            )
+            output[f"planning_loss_cls_{i}"] = cls_loss
+            output[f"planning_loss_reg_{i}"] = reg_loss
+            output[f"planning_loss_status_{i}"] = status_loss
+        return output
+
+    # ------------------------------------------------------------------ #
+    # Post-process
+    # ------------------------------------------------------------------ #
+
+    @force_fp32(apply_to=("model_outs",))
+    def post_process(self, det_output, motion_output, planning_output, data):
+        """Decode motion and planning outputs for evaluation/visualisation."""
+        # ``quality`` may be absent (None) — the decoder must tolerate that.
+        motion_result = self.motion_decoder.decode(
+            det_output["classification"],
+            det_output["prediction"],
+            det_output.get("instance_id"),
+            det_output.get("quality"),
+            motion_output,
+        )
+        planning_result = self.planning_decoder.decode(
+            det_output, motion_output, planning_output, data
+        )
+        return motion_result, planning_result
diff --git a/projects/mmdet3d_plugin/models/motion/motion_blocks.py b/projects/mmdet3d_plugin/models/motion/motion_blocks.py
new file mode 100644
index 0000000..f7e80e2
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/motion/motion_blocks.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+from mmcv.cnn import Linear, Scale, bias_init_with_prob
+from mmcv.runner.base_module import Sequential, BaseModule
+from mmcv.cnn import xavier_init
+from mmcv.cnn.bricks.registry import (
+ PLUGIN_LAYERS,
+)
+
+from projects.mmdet3d_plugin.core.box3d import *
+from ..blocks import linear_relu_ln
+
+
@PLUGIN_LAYERS.register_module()
class MotionPlanningRefinementModule(BaseModule):
    """Refinement head that scores and regresses multi-modal agent motion
    and ego planning trajectories, plus an ego status prediction.
    """

    def __init__(
        self,
        embed_dims=256,
        fut_ts=12,
        fut_mode=6,
        ego_fut_ts=6,
        ego_fut_mode=3,
    ):
        super(MotionPlanningRefinementModule, self).__init__()
        self.embed_dims = embed_dims
        self.fut_ts = fut_ts
        self.fut_mode = fut_mode
        self.ego_fut_ts = ego_fut_ts
        self.ego_fut_mode = ego_fut_mode

        # Local factories: all cls branches share one shape, all reg/status
        # branches share a 2-hidden-layer MLP shape.
        def make_cls_branch():
            return nn.Sequential(
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(embed_dims, 1),
            )

        def make_mlp_branch(out_dims):
            return nn.Sequential(
                nn.Linear(embed_dims, embed_dims),
                nn.ReLU(),
                nn.Linear(embed_dims, embed_dims),
                nn.ReLU(),
                nn.Linear(embed_dims, out_dims),
            )

        self.motion_cls_branch = make_cls_branch()
        self.motion_reg_branch = make_mlp_branch(fut_ts * 2)
        self.plan_cls_branch = make_cls_branch()
        self.plan_reg_branch = make_mlp_branch(ego_fut_ts * 2)
        # 10 status channels (semantics defined by the dataset's ego_status).
        self.plan_status_branch = make_mlp_branch(10)

    def init_weight(self):
        # Focal-loss style prior on the classification logits.
        prior_bias = bias_init_with_prob(0.01)
        for branch in (self.motion_cls_branch, self.plan_cls_branch):
            nn.init.constant_(branch[-1].bias, prior_bias)

    def forward(
        self,
        motion_query,
        plan_query,
        ego_feature,
        ego_anchor_embed,
    ):
        """Return (motion_cls, motion_reg, plan_cls, plan_reg, planning_status)."""
        bs, num_anchor = motion_query.shape[:2]
        motion_cls = self.motion_cls_branch(motion_query).squeeze(-1)
        motion_reg = self.motion_reg_branch(motion_query).reshape(
            bs, num_anchor, self.fut_mode, self.fut_ts, 2
        )
        plan_cls = self.plan_cls_branch(plan_query).squeeze(-1)
        # 3 driving commands x ego_fut_mode plan modes for the single ego token.
        plan_reg = self.plan_reg_branch(plan_query).reshape(
            bs, 1, 3 * self.ego_fut_mode, self.ego_fut_ts, 2
        )
        planning_status = self.plan_status_branch(ego_feature + ego_anchor_embed)
        return motion_cls, motion_reg, plan_cls, plan_reg, planning_status
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/models/motion/motion_planning_head.py b/projects/mmdet3d_plugin/models/motion/motion_planning_head.py
new file mode 100644
index 0000000..64d0b2b
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/motion/motion_planning_head.py
@@ -0,0 +1,509 @@
+from typing import List, Optional, Tuple, Union
+import warnings
+import copy
+
+import numpy as np
+import cv2
+import torch
+import torch.nn as nn
+
+from mmcv.utils import build_from_cfg
+from mmcv.cnn import Linear, bias_init_with_prob
+from mmcv.runner import BaseModule, force_fp32
+from mmcv.cnn.bricks.registry import (
+ ATTENTION,
+ PLUGIN_LAYERS,
+ POSITIONAL_ENCODING,
+ FEEDFORWARD_NETWORK,
+ NORM_LAYERS,
+)
+from mmdet.core import reduce_mean
+from mmdet.models import HEADS
+from mmdet.core.bbox.builder import BBOX_SAMPLERS, BBOX_CODERS
+from mmdet.models import build_loss
+
+from projects.mmdet3d_plugin.datasets.utils import box3d_to_corners
+from projects.mmdet3d_plugin.core.box3d import *
+
+from ..attention import gen_sineembed_for_position
+from ..blocks import linear_relu_ln
+from ..instance_bank import topk
+
+
@HEADS.register_module()
class MotionPlanningHead(BaseModule):
    """Joint agent motion-forecasting and ego-planning head.

    Consumes detection / map head outputs plus image feature maps, runs a
    decoder described by ``operation_order`` (graph attention, deformable
    aggregation, FFN/norm, refinement), and predicts per-agent multi-modal
    future trajectories and multi-modal ego plans with an ego status.

    BUGFIX: the ``force_fp32`` decorators below previously used
    ``apply_to=("model_outs")`` — a bare string, which ``force_fp32``
    iterates character-by-character, so the fp32 cast silently never applied
    to ``model_outs``. They now use the tuple ``("model_outs",)``.
    """

    def __init__(
        self,
        fut_ts=12,
        fut_mode=6,
        ego_fut_ts=6,
        ego_fut_mode=3,
        motion_anchor=None,
        plan_anchor=None,
        embed_dims=256,
        decouple_attn=False,
        instance_queue=None,
        operation_order=None,
        temp_graph_model=None,
        graph_model=None,
        cross_graph_model=None,
        deformable_model=None,
        norm_layer=None,
        ffn=None,
        refine_layer=None,
        motion_sampler=None,
        motion_loss_cls=None,
        motion_loss_reg=None,
        planning_sampler=None,
        plan_loss_cls=None,
        plan_loss_reg=None,
        plan_loss_status=None,
        motion_decoder=None,
        planning_decoder=None,
        num_det=50,
        num_map=10,
    ):
        super(MotionPlanningHead, self).__init__()
        self.fut_ts = fut_ts
        self.fut_mode = fut_mode
        self.ego_fut_ts = ego_fut_ts
        self.ego_fut_mode = ego_fut_mode

        self.decouple_attn = decouple_attn
        self.operation_order = operation_order

        # =========== build modules ===========
        def build(cfg, registry):
            # Allow any sub-module to be disabled by passing None.
            if cfg is None:
                return None
            return build_from_cfg(cfg, registry)

        self.instance_queue = build(instance_queue, PLUGIN_LAYERS)
        self.motion_sampler = build(motion_sampler, BBOX_SAMPLERS)
        self.planning_sampler = build(planning_sampler, BBOX_SAMPLERS)
        self.motion_decoder = build(motion_decoder, BBOX_CODERS)
        self.planning_decoder = build(planning_decoder, BBOX_CODERS)
        # Maps each op name in operation_order to (config, registry).
        self.op_config_map = {
            "temp_gnn": [temp_graph_model, ATTENTION],
            "gnn": [graph_model, ATTENTION],
            "cross_gnn": [cross_graph_model, ATTENTION],
            "deformable": [deformable_model, ATTENTION],
            "norm": [norm_layer, NORM_LAYERS],
            "ffn": [ffn, FEEDFORWARD_NETWORK],
            "refine": [refine_layer, PLUGIN_LAYERS],
        }
        self.layers = nn.ModuleList(
            [
                build(*self.op_config_map.get(op, [None, None]))
                for op in self.operation_order
            ]
        )
        self.embed_dims = embed_dims

        if self.decouple_attn:
            # Decoupled attention: content and position are concatenated, so
            # attention runs on 2x-width vectors, projected back afterwards.
            self.fc_before = nn.Linear(
                self.embed_dims, self.embed_dims * 2, bias=False
            )
            self.fc_after = nn.Linear(
                self.embed_dims * 2, self.embed_dims, bias=False
            )
        else:
            self.fc_before = nn.Identity()
            self.fc_after = nn.Identity()

        self.motion_loss_cls = build_loss(motion_loss_cls)
        self.motion_loss_reg = build_loss(motion_loss_reg)
        self.plan_loss_cls = build_loss(plan_loss_cls)
        self.plan_loss_reg = build_loss(plan_loss_reg)
        self.plan_loss_status = build_loss(plan_loss_status)

        # motion init: per-class kinematic anchors loaded from a .npy file.
        motion_anchor = np.load(motion_anchor)
        self.motion_anchor = nn.Parameter(
            torch.tensor(motion_anchor, dtype=torch.float32),
            requires_grad=False,
        )
        self.motion_anchor_encoder = nn.Sequential(
            *linear_relu_ln(embed_dims, 1, 1),
            Linear(embed_dims, embed_dims),
        )

        # plan anchor init: command-conditioned ego trajectory anchors.
        plan_anchor = np.load(plan_anchor)
        self.plan_anchor = nn.Parameter(
            torch.tensor(plan_anchor, dtype=torch.float32),
            requires_grad=False,
        )
        self.plan_anchor_encoder = nn.Sequential(
            *linear_relu_ln(embed_dims, 1, 1),
            Linear(embed_dims, embed_dims),
        )

        self.num_det = num_det
        self.num_map = num_map

    def init_weights(self):
        """Xavier-init attention/FFN layers; refine layers keep their own init."""
        for i, op in enumerate(self.operation_order):
            if self.layers[i] is None:
                continue
            elif op != "refine":
                for p in self.layers[i].parameters():
                    if p.dim() > 1:
                        nn.init.xavier_uniform_(p)
        for m in self.modules():
            if hasattr(m, "init_weight"):
                m.init_weight()

    def get_motion_anchor(
        self,
        classification,
        prediction,
    ):
        """Select per-agent motion anchors by predicted class and rotate them
        from agent frame into lidar frame using the detected box yaw."""
        cls_ids = classification.argmax(dim=-1)
        motion_anchor = self.motion_anchor[cls_ids]
        # Detach: anchor selection must not backprop into the det head.
        prediction = prediction.detach()
        return self._agent2lidar(motion_anchor, prediction)

    def _agent2lidar(self, trajs, boxes):
        """Rotate agent-frame trajectories into the lidar frame (per box yaw)."""
        yaw = torch.atan2(boxes[..., SIN_YAW], boxes[..., COS_YAW])
        cos_yaw = torch.cos(yaw)
        sin_yaw = torch.sin(yaw)
        rot_mat_T = torch.stack(
            [
                torch.stack([cos_yaw, sin_yaw]),
                torch.stack([-sin_yaw, cos_yaw]),
            ]
        )

        trajs_lidar = torch.einsum('abcij,jkab->abcik', trajs, rot_mat_T)
        return trajs_lidar

    def graph_model(
        self,
        index,
        query,
        key=None,
        value=None,
        query_pos=None,
        key_pos=None,
        **kwargs,
    ):
        """Run attention layer ``index``, handling decoupled attention where
        content and positional embeddings are concatenated instead of added."""
        if self.decouple_attn:
            query = torch.cat([query, query_pos], dim=-1)
            if key is not None:
                key = torch.cat([key, key_pos], dim=-1)
            query_pos, key_pos = None, None
        if value is not None:
            value = self.fc_before(value)
        return self.fc_after(
            self.layers[index](
                query,
                key,
                value,
                query_pos=query_pos,
                key_pos=key_pos,
                **kwargs,
            )
        )

    def forward(
        self,
        det_output,
        map_output,
        feature_maps,
        metas,
        anchor_encoder,
        mask,
        anchor_handler,
    ):
        """Run the motion/planning decoder.

        Returns:
            (motion_output, planning_output): dicts of per-refine-layer
            classification / prediction lists plus temporal queue state.
        """
        # =========== det/map feature/anchor ===========
        instance_feature = det_output["instance_feature"]
        anchor_embed = det_output["anchor_embed"]
        det_classification = det_output["classification"][-1].sigmoid()
        det_anchors = det_output["prediction"][-1]
        det_confidence = det_classification.max(dim=-1).values
        _, (instance_feature_selected, anchor_embed_selected) = topk(
            det_confidence, self.num_det, instance_feature, anchor_embed
        )

        # NOTE(review): if map_output is None while operation_order contains
        # "cross_gnn", the map-selected tensors below are unbound — configs
        # must keep the two consistent.
        if map_output is not None:
            map_instance_feature = map_output["instance_feature"]
            map_anchor_embed = map_output["anchor_embed"]
            map_classification = map_output["classification"][-1].sigmoid()
            map_anchors = map_output["prediction"][-1]
            map_confidence = map_classification.max(dim=-1).values
            _, (map_instance_feature_selected, map_anchor_embed_selected) = topk(
                map_confidence, self.num_map, map_instance_feature, map_anchor_embed
            )

        # =========== get ego/temporal feature/anchor ===========
        bs, num_anchor, dim = instance_feature.shape
        (
            ego_feature,
            ego_anchor,
            temp_instance_feature,
            temp_anchor,
            temp_mask,
        ) = self.instance_queue.get(
            det_output,
            feature_maps,
            metas,
            bs,
            mask,
            anchor_handler,
        )
        ego_anchor_embed = anchor_encoder(ego_anchor)
        temp_anchor_embed = anchor_encoder(temp_anchor)
        temp_instance_feature = temp_instance_feature.flatten(0, 1)
        temp_anchor_embed = temp_anchor_embed.flatten(0, 1)
        temp_mask = temp_mask.flatten(0, 1)

        # =========== mode anchor init ===========
        motion_anchor = self.get_motion_anchor(det_classification, det_anchors)
        plan_anchor = torch.tile(
            self.plan_anchor[None], (bs, 1, 1, 1, 1)
        )

        # =========== mode query init ===========
        # Queries are seeded from the sine embedding of each anchor endpoint.
        motion_mode_query = self.motion_anchor_encoder(gen_sineembed_for_position(motion_anchor[..., -1, :]))
        plan_pos = gen_sineembed_for_position(plan_anchor[..., -1, :])
        plan_mode_query = self.plan_anchor_encoder(plan_pos).flatten(1, 2).unsqueeze(1)

        # =========== cat instance and ego ===========
        # The ego token is appended as the last instance (index num_anchor).
        instance_feature_selected = torch.cat([instance_feature_selected, ego_feature], dim=1)
        anchor_embed_selected = torch.cat([anchor_embed_selected, ego_anchor_embed], dim=1)

        instance_feature = torch.cat([instance_feature, ego_feature], dim=1)
        anchor_embed = torch.cat([anchor_embed, ego_anchor_embed], dim=1)

        # =================== forward the layers ====================
        motion_classification = []
        motion_prediction = []
        planning_classification = []
        planning_prediction = []
        planning_status = []
        for i, op in enumerate(self.operation_order):
            if self.layers[i] is None:
                continue
            elif op == "temp_gnn":
                # Cross-attend each current instance to its own temporal queue.
                instance_feature = self.graph_model(
                    i,
                    instance_feature.flatten(0, 1).unsqueeze(1),
                    temp_instance_feature,
                    temp_instance_feature,
                    query_pos=anchor_embed.flatten(0, 1).unsqueeze(1),
                    key_pos=temp_anchor_embed,
                    key_padding_mask=temp_mask,
                )
                instance_feature = instance_feature.reshape(bs, num_anchor + 1, dim)
            elif op == "gnn":
                instance_feature = self.graph_model(
                    i,
                    instance_feature,
                    instance_feature_selected,
                    instance_feature_selected,
                    query_pos=anchor_embed,
                    key_pos=anchor_embed_selected,
                )
            elif op == "norm" or op == "ffn":
                instance_feature = self.layers[i](instance_feature)
            elif op == "cross_gnn":
                instance_feature = self.layers[i](
                    instance_feature,
                    key=map_instance_feature_selected,
                    query_pos=anchor_embed,
                    key_pos=map_anchor_embed_selected,
                )
            elif op == "deformable":
                # Apply deformable cross-attention to sensor features for
                # agent instances only (ego token has no well-defined 3D box).
                agent_feature = self.layers[i](
                    instance_feature[:, :num_anchor],
                    det_anchors,
                    anchor_embed[:, :num_anchor],
                    feature_maps,
                    metas,
                )
                instance_feature = torch.cat(
                    [agent_feature, instance_feature[:, num_anchor:]], dim=1
                )
            elif op == "refine":
                motion_query = motion_mode_query + (instance_feature + anchor_embed)[:, :num_anchor].unsqueeze(2)
                plan_query = plan_mode_query + (instance_feature + anchor_embed)[:, num_anchor:].unsqueeze(2)
                (
                    motion_cls,
                    motion_reg,
                    plan_cls,
                    plan_reg,
                    plan_status,
                ) = self.layers[i](
                    motion_query,
                    plan_query,
                    instance_feature[:, num_anchor:],
                    anchor_embed[:, num_anchor:],
                )
                motion_classification.append(motion_cls)
                motion_prediction.append(motion_reg)
                planning_classification.append(plan_cls)
                planning_prediction.append(plan_reg)
                planning_status.append(plan_status)
                # Update mode anchor queries for the next decoder iteration.
                # cumsum converts delta trajectories to absolute endpoints.
                motion_anchor_upd = motion_reg.detach().cumsum(dim=-2)
                motion_mode_query = self.motion_anchor_encoder(
                    gen_sineembed_for_position(motion_anchor_upd[..., -1, :])
                )
                plan_anchor_upd = plan_reg.detach().cumsum(dim=-2)
                plan_mode_query = self.plan_anchor_encoder(
                    gen_sineembed_for_position(plan_anchor_upd[..., -1, :])
                ).flatten(1, 2).unsqueeze(1)

        # NOTE(review): assumes operation_order contains at least one "refine"
        # op, otherwise plan_status below is unbound.
        self.instance_queue.cache_motion(instance_feature[:, :num_anchor], det_output, metas)
        self.instance_queue.cache_planning(instance_feature[:, num_anchor:], plan_status)

        motion_output = {
            "classification": motion_classification,
            "prediction": motion_prediction,
            "period": self.instance_queue.period,
            "anchor_queue": self.instance_queue.anchor_queue,
        }
        planning_output = {
            "classification": planning_classification,
            "prediction": planning_prediction,
            "status": planning_status,
            "period": self.instance_queue.ego_period,
            "anchor_queue": self.instance_queue.ego_anchor_queue,
        }
        return motion_output, planning_output

    def loss(self,
        motion_model_outs,
        planning_model_outs,
        data,
        motion_loss_cache
    ):
        """Aggregate motion and planning losses into one dict."""
        loss = {}
        motion_loss = self.loss_motion(motion_model_outs, data, motion_loss_cache)
        loss.update(motion_loss)
        planning_loss = self.loss_planning(planning_model_outs, data)
        loss.update(planning_loss)
        return loss

    # BUGFIX: apply_to must be a tuple; ("model_outs") was a plain string.
    @force_fp32(apply_to=("model_outs",))
    def loss_motion(self, model_outs, data, motion_loss_cache):
        """Per-decoder-layer motion classification and regression losses.

        Uses the detection matcher's assignment (``motion_loss_cache``) to
        pair predicted agents with GT future trajectories.
        """
        cls_scores = model_outs["classification"]
        reg_preds = model_outs["prediction"]
        output = {}
        for decoder_idx, (cls, reg) in enumerate(
            zip(cls_scores, reg_preds)
        ):
            (
                cls_target,
                cls_weight,
                reg_pred,
                reg_target,
                reg_weight,
                num_pos
            ) = self.motion_sampler.sample(
                reg,
                data["gt_agent_fut_trajs"],
                data["gt_agent_fut_masks"],
                motion_loss_cache,
            )
            # Average positives across ranks; clamp to avoid division by zero.
            num_pos = max(reduce_mean(num_pos), 1.0)

            cls = cls.flatten(end_dim=1)
            cls_target = cls_target.flatten(end_dim=1)
            cls_weight = cls_weight.flatten(end_dim=1)
            cls_loss = self.motion_loss_cls(cls, cls_target, weight=cls_weight, avg_factor=num_pos)

            reg_weight = reg_weight.flatten(end_dim=1)
            reg_pred = reg_pred.flatten(end_dim=1)
            reg_target = reg_target.flatten(end_dim=1)
            reg_weight = reg_weight.unsqueeze(-1)
            # Supervise absolute positions: cumsum turns per-step deltas into
            # cumulative displacements on both sides.
            reg_pred = reg_pred.cumsum(dim=-2)
            reg_target = reg_target.cumsum(dim=-2)
            reg_loss = self.motion_loss_reg(
                reg_pred, reg_target, weight=reg_weight, avg_factor=num_pos
            )

            output.update(
                {
                    f"motion_loss_cls_{decoder_idx}": cls_loss,
                    f"motion_loss_reg_{decoder_idx}": reg_loss,
                }
            )

        return output

    # BUGFIX: apply_to must be a tuple; ("model_outs") was a plain string.
    @force_fp32(apply_to=("model_outs",))
    def loss_planning(self, model_outs, data):
        """Per-decoder-layer planning classification/regression/status losses."""
        cls_scores = model_outs["classification"]
        reg_preds = model_outs["prediction"]
        status_preds = model_outs["status"]
        output = {}
        for decoder_idx, (cls, reg, status) in enumerate(
            zip(cls_scores, reg_preds, status_preds)
        ):
            (
                cls,
                cls_target,
                cls_weight,
                reg_pred,
                reg_target,
                reg_weight,
            ) = self.planning_sampler.sample(
                cls,
                reg,
                data['gt_ego_fut_trajs'],
                data['gt_ego_fut_masks'],
                data,
            )
            cls = cls.flatten(end_dim=1)
            cls_target = cls_target.flatten(end_dim=1)
            cls_weight = cls_weight.flatten(end_dim=1)
            cls_loss = self.plan_loss_cls(cls, cls_target, weight=cls_weight)

            reg_weight = reg_weight.flatten(end_dim=1)
            reg_pred = reg_pred.flatten(end_dim=1)
            reg_target = reg_target.flatten(end_dim=1)
            reg_weight = reg_weight.unsqueeze(-1)

            reg_loss = self.plan_loss_reg(
                reg_pred, reg_target, weight=reg_weight
            )
            status_loss = self.plan_loss_status(status.squeeze(1), data['ego_status'])

            output.update(
                {
                    f"planning_loss_cls_{decoder_idx}": cls_loss,
                    f"planning_loss_reg_{decoder_idx}": reg_loss,
                    f"planning_loss_status_{decoder_idx}": status_loss,
                }
            )

        return output

    # BUGFIX: apply_to must be a tuple; ("model_outs") was a plain string.
    @force_fp32(apply_to=("model_outs",))
    def post_process(
        self,
        det_output,
        motion_output,
        planning_output,
        data,
    ):
        """Decode raw outputs into (motion_result, planning_result)."""
        motion_result = self.motion_decoder.decode(
            det_output["classification"],
            det_output["prediction"],
            det_output.get("instance_id"),
            det_output.get("quality"),
            motion_output,
        )
        planning_result = self.planning_decoder.decode(
            det_output,
            motion_output,
            planning_output,
            data,
        )

        return motion_result, planning_result
\ No newline at end of file
diff --git a/projects/mmdet3d_plugin/models/motion/target.py b/projects/mmdet3d_plugin/models/motion/target.py
new file mode 100644
index 0000000..521044f
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/motion/target.py
@@ -0,0 +1,108 @@
+import torch
+
+from mmdet.core.bbox.builder import BBOX_SAMPLERS
+
+__all__ = ["MotionTarget", "PlanningTarget"]
+
+
def get_cls_target(
    reg_preds,
    reg_target,
    reg_weight,
):
    """Return, per agent, the index of the trajectory mode closest to GT.

    Distances are computed on cumulative (absolute) positions, masked by
    ``reg_weight`` over valid timesteps, and averaged over time.
    ``reg_preds``: (bs, num_pred, mode, ts, d); result: (bs, num_pred) long.
    """
    bs, num_pred, num_mode, num_ts, num_dim = reg_preds.shape
    abs_preds = reg_preds.cumsum(dim=-2)
    abs_target = reg_target.cumsum(dim=-2)
    # (bs, num_pred, mode, ts) per-step L2 distance to GT, invalid steps zeroed.
    step_dist = torch.linalg.norm(abs_target.unsqueeze(2) - abs_preds, dim=-1)
    step_dist = step_dist * reg_weight.unsqueeze(2)
    return step_dist.mean(dim=-1).argmin(dim=-1)
+
def get_best_reg(
    reg_preds,
    reg_target,
    reg_weight,
):
    """Gather, per agent, the predicted trajectory of the best-matching mode.

    Mode selection mirrors :func:`get_cls_target` (masked mean endpoint
    distance on cumulative positions).
    ``reg_preds``: (bs, num_pred, mode, ts, d); result: (bs, num_pred, ts, d).
    """
    bs, num_pred, num_mode, num_ts, num_dim = reg_preds.shape
    abs_preds = reg_preds.cumsum(dim=-2)
    abs_target = reg_target.cumsum(dim=-2)
    step_dist = torch.linalg.norm(abs_target.unsqueeze(2) - abs_preds, dim=-1)
    step_dist = step_dist * reg_weight.unsqueeze(2)
    best_mode = step_dist.mean(dim=-1).argmin(dim=-1)
    # Broadcast the winning mode index over (ts, d) and gather that slice.
    gather_idx = best_mode[..., None, None, None].expand(-1, -1, -1, num_ts, num_dim)
    return torch.gather(reg_preds, 2, gather_idx).squeeze(2)
+
+
@BBOX_SAMPLERS.register_module()
class MotionTarget():
    """Builds motion supervision targets from the detection matcher's
    assignment (passed in via ``motion_loss_cache['indices']``)."""

    def __init__(
        self,
    ):
        super(MotionTarget, self).__init__()

    def sample(
        self,
        reg_pred,
        gt_reg_target,
        gt_reg_mask,
        motion_loss_cache,
    ):
        """Scatter matched GT futures into dense per-anchor targets.

        Returns (cls_target, cls_weight, best_reg, reg_target, reg_weight,
        num_pos), where num_pos is a 1-element tensor counting matched anchors.
        """
        bs, num_anchor, _, num_ts, num_dim = reg_pred.shape
        reg_target = reg_pred.new_zeros((bs, num_anchor, num_ts, num_dim))
        reg_weight = reg_pred.new_zeros((bs, num_anchor, num_ts))
        num_pos = reg_pred.new_tensor([0])
        for batch_idx, (pred_idx, target_idx) in enumerate(motion_loss_cache['indices']):
            # Samples without any GT agents contribute no supervision.
            if len(gt_reg_target[batch_idx]) == 0:
                continue
            reg_target[batch_idx, pred_idx] = gt_reg_target[batch_idx][target_idx]
            reg_weight[batch_idx, pred_idx] = gt_reg_mask[batch_idx][target_idx]
            num_pos += len(pred_idx)

        cls_target = get_cls_target(reg_pred, reg_target, reg_weight)
        # An anchor is classified only if it has at least one valid GT step.
        cls_weight = reg_weight.any(dim=-1)
        best_reg = get_best_reg(reg_pred, reg_target, reg_weight)

        return cls_target, cls_weight, best_reg, reg_target, reg_weight, num_pos
+
+
@BBOX_SAMPLERS.register_module()
class PlanningTarget():
    """Builds planning supervision targets for the ego vehicle, selecting the
    prediction slice that corresponds to the ground-truth driving command."""

    def __init__(
        self,
        ego_fut_ts,
        ego_fut_mode,
    ):
        super(PlanningTarget, self).__init__()
        self.ego_fut_ts = ego_fut_ts
        self.ego_fut_mode = ego_fut_mode

    def sample(
        self,
        cls_pred,
        reg_pred,
        gt_reg_target,
        gt_reg_mask,
        data,
    ):
        """Return (cls_pred, cls_target, cls_weight, best_reg, gt_reg_target,
        gt_reg_mask) restricted to the commanded plan modes."""
        gt_reg_target = gt_reg_target.unsqueeze(1)
        gt_reg_mask = gt_reg_mask.unsqueeze(1)

        bs = reg_pred.shape[0]
        batch_range = torch.arange(bs, device=reg_pred.device)
        cmd = data['gt_ego_fut_cmd'].argmax(dim=-1)

        # Predictions are laid out as 3 driving commands x ego_fut_mode modes;
        # keep only the slice matching each sample's GT command.
        cls_by_cmd = cls_pred.reshape(bs, 3, 1, self.ego_fut_mode)
        reg_by_cmd = reg_pred.reshape(bs, 3, 1, self.ego_fut_mode, self.ego_fut_ts, 2)
        cls_pred = cls_by_cmd[batch_range, cmd]
        reg_pred = reg_by_cmd[batch_range, cmd]

        cls_target = get_cls_target(reg_pred, gt_reg_target, gt_reg_mask)
        cls_weight = gt_reg_mask.any(dim=-1)
        best_reg = get_best_reg(reg_pred, gt_reg_target, gt_reg_mask)

        return cls_pred, cls_target, cls_weight, best_reg, gt_reg_target, gt_reg_mask
+
+
diff --git a/projects/mmdet3d_plugin/models/sparsedrive.py b/projects/mmdet3d_plugin/models/sparsedrive.py
new file mode 100644
index 0000000..16837dd
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/sparsedrive.py
@@ -0,0 +1,127 @@
+from inspect import signature
+
+import torch
+
+from mmcv.runner import force_fp32, auto_fp16
+from mmcv.utils import build_from_cfg
+from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
+from mmdet.models import (
+ DETECTORS,
+ BaseDetector,
+ build_backbone,
+ build_head,
+ build_neck,
+)
+from .grid_mask import GridMask
+
+try:
+ from ..ops import feature_maps_format
+ DAF_VALID = True
+except:
+ DAF_VALID = False
+
+__all__ = ["SparseDrive"]
+
+
@DETECTORS.register_module()
class SparseDrive(BaseDetector):
    """End-to-end sparse driving model: image backbone (+ optional neck)
    feeding a multi-task head, with optional grid-mask augmentation,
    deformable feature packing and a dense depth auxiliary branch."""

    def __init__(
        self,
        img_backbone,
        head,
        img_neck=None,
        init_cfg=None,
        train_cfg=None,
        test_cfg=None,
        pretrained=None,
        use_grid_mask=True,
        use_deformable_func=False,
        depth_branch=None,
    ):
        super(SparseDrive, self).__init__(init_cfg=init_cfg)
        if pretrained is not None:
            # BUGFIX: was `backbone.pretrained = pretrained`, but no `backbone`
            # name exists in this scope (NameError on any use of `pretrained`).
            # Inject the checkpoint into the backbone config before building.
            img_backbone.pretrained = pretrained
        self.img_backbone = build_backbone(img_backbone)
        # BUGFIX: `self.img_neck` was only assigned when a neck config was
        # given, so neck-less models crashed in extract_feat's None-check.
        self.img_neck = build_neck(img_neck) if img_neck is not None else None
        self.head = build_head(head)
        self.use_grid_mask = use_grid_mask
        if use_deformable_func:
            assert DAF_VALID, "deformable_aggregation needs to be set up."
        self.use_deformable_func = use_deformable_func
        if depth_branch is not None:
            self.depth_branch = build_from_cfg(depth_branch, PLUGIN_LAYERS)
        else:
            self.depth_branch = None
        if use_grid_mask:
            self.grid_mask = GridMask(
                True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
            )

    @auto_fp16(apply_to=("img",), out_fp32=True)
    def extract_feat(self, img, return_depth=False, metas=None):
        """Extract multi-scale (and multi-camera) feature maps from images.

        Args:
            img: (bs, C, H, W) or (bs, num_cams, C, H, W) tensor.
            return_depth: also run the depth branch and return its output.
            metas: per-sample meta dict; only read when the backbone accepts
                a ``metas`` kwarg or when the depth branch needs "focal".

        Returns:
            feature_maps, or (feature_maps, depths) when return_depth is True.
        """
        bs = img.shape[0]
        if img.dim() == 5:  # multi-view: fold cameras into the batch dim
            num_cams = img.shape[1]
            img = img.flatten(end_dim=1)
        else:
            num_cams = 1
        if self.use_grid_mask:
            img = self.grid_mask(img)
        # Some backbones consume metas (e.g. for camera-aware processing).
        if "metas" in signature(self.img_backbone.forward).parameters:
            feature_maps = self.img_backbone(img, num_cams, metas=metas)
        else:
            feature_maps = self.img_backbone(img)
        if self.img_neck is not None:
            feature_maps = list(self.img_neck(feature_maps))
        # Restore the (bs, num_cams, ...) layout per scale.
        for i, feat in enumerate(feature_maps):
            feature_maps[i] = torch.reshape(
                feat, (bs, num_cams) + feat.shape[1:]
            )
        if return_depth and self.depth_branch is not None:
            depths = self.depth_branch(feature_maps, metas.get("focal"))
        else:
            depths = None
        if self.use_deformable_func:
            # Pack into the flat column layout expected by the custom op.
            feature_maps = feature_maps_format(feature_maps)
        if return_depth:
            return feature_maps, depths
        return feature_maps

    @force_fp32(apply_to=("img",))
    def forward(self, img, **data):
        if self.training:
            return self.forward_train(img, **data)
        else:
            return self.forward_test(img, **data)

    def forward_train(self, img, **data):
        """Training step: returns the loss dict (incl. optional depth loss)."""
        feature_maps, depths = self.extract_feat(img, True, data)
        model_outs = self.head(feature_maps, data)
        output = self.head.loss(model_outs, data)
        if depths is not None and "gt_depth" in data:
            output["loss_dense_depth"] = self.depth_branch.loss(
                depths, data["gt_depth"]
            )
        return output

    def forward_test(self, img, **data):
        if isinstance(img, list):
            return self.aug_test(img, **data)
        else:
            return self.simple_test(img, **data)

    def simple_test(self, img, **data):
        """Single-sample inference; returns a list of per-sample result dicts."""
        feature_maps = self.extract_feat(img)

        model_outs = self.head(feature_maps, data)
        results = self.head.post_process(model_outs, data)
        output = [{"img_bbox": result} for result in results]
        return output

    def aug_test(self, img, **data):
        # fake test time augmentation: just unwrap the first (only) view.
        for key in data.keys():
            if isinstance(data[key], list):
                data[key] = data[key][0]
        return self.simple_test(img[0], **data)
diff --git a/projects/mmdet3d_plugin/models/sparsedrive_head.py b/projects/mmdet3d_plugin/models/sparsedrive_head.py
new file mode 100644
index 0000000..79eec57
--- /dev/null
+++ b/projects/mmdet3d_plugin/models/sparsedrive_head.py
@@ -0,0 +1,124 @@
+from typing import List, Optional, Tuple, Union
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from mmcv.runner import BaseModule
+from mmdet.models import HEADS
+from mmdet.models import build_head
+
+
@HEADS.register_module()
class SparseDriveHead(BaseModule):
    """Task dispatcher composing detection, online mapping and
    motion/planning sub-heads, each enabled via ``task_config``
    ("with_det" / "with_map" / "with_motion_plan")."""

    def __init__(
        self,
        task_config: dict,
        # NOTE(review): the `dict` *type* as default is unusual (a disabled
        # task never builds its head, so the default is never used); kept
        # for interface compatibility.
        det_head = dict,
        map_head = dict,
        motion_plan_head = dict,
        init_cfg=None,
        **kwargs,
    ):
        super(SparseDriveHead, self).__init__(init_cfg)
        self.task_config = task_config
        if self.task_config['with_det']:
            self.det_head = build_head(det_head)
        if self.task_config['with_map']:
            self.map_head = build_head(map_head)
        if self.task_config['with_motion_plan']:
            self.motion_plan_head = build_head(motion_plan_head)

    def init_weights(self):
        if self.task_config['with_det']:
            self.det_head.init_weights()
        if self.task_config['with_map']:
            self.map_head.init_weights()
        if self.task_config['with_motion_plan']:
            self.motion_plan_head.init_weights()

    def forward(
        self,
        feature_maps: Union[torch.Tensor, List],
        metas: dict,
    ):
        """Run the enabled sub-heads; disabled tasks yield None outputs."""
        if self.task_config['with_det']:
            det_output = self.det_head(feature_maps, metas)
        else:
            det_output = None

        if self.task_config['with_map']:
            map_output = self.map_head(feature_maps, metas)
        else:
            map_output = None

        if self.task_config['with_motion_plan']:
            # The motion/plan head reuses the det head's anchor encoder and
            # instance bank state, so with_motion_plan implies with_det.
            motion_output, planning_output = self.motion_plan_head(
                det_output,
                map_output,
                feature_maps,
                metas,
                self.det_head.anchor_encoder,
                self.det_head.instance_bank.mask,
                self.det_head.instance_bank.anchor_handler,
            )
        else:
            motion_output, planning_output = None, None

        return det_output, map_output, motion_output, planning_output

    def loss(self, model_outs, data):
        """Aggregate losses of every enabled task into one dict."""
        det_output, map_output, motion_output, planning_output = model_outs
        losses = dict()
        if self.task_config['with_det']:
            loss_det = self.det_head.loss(det_output, data)
            losses.update(loss_det)

        if self.task_config['with_map']:
            loss_map = self.map_head.loss(map_output, data)
            losses.update(loss_map)

        if self.task_config['with_motion_plan']:
            # Reuse the detection matcher's assignment for motion targets.
            motion_loss_cache = dict(
                indices=self.det_head.sampler.indices,
            )
            loss_motion = self.motion_plan_head.loss(
                motion_output,
                planning_output,
                data,
                motion_loss_cache
            )
            losses.update(loss_motion)

        return losses

    def post_process(self, model_outs, data):
        """Decode raw outputs of all enabled tasks into per-sample dicts."""
        det_output, map_output, motion_output, planning_output = model_outs
        batch_size = None
        if self.task_config['with_det']:
            det_result = self.det_head.post_process(det_output)
            batch_size = len(det_result)

        if self.task_config['with_map']:
            map_result = self.map_head.post_process(map_output)
            batch_size = len(map_result)

        if self.task_config['with_motion_plan']:
            motion_result, planning_result = self.motion_plan_head.post_process(
                det_output,
                motion_output,
                planning_output,
                data,
            )
            if batch_size is None:
                # BUGFIX: batch_size was unbound when neither det nor map ran.
                batch_size = len(motion_result)

        # BUGFIX: `[dict()] * batch_size` aliased ONE shared dict, so every
        # results[i].update(...) mutated all entries; build distinct dicts.
        results = [dict() for _ in range(batch_size)]
        for i in range(batch_size):
            if self.task_config['with_det']:
                results[i].update(det_result[i])
            if self.task_config['with_map']:
                results[i].update(map_result[i])
            if self.task_config['with_motion_plan']:
                results[i].update(motion_result[i])
                results[i].update(planning_result[i])

        return results
diff --git a/projects/mmdet3d_plugin/ops/__init__.py b/projects/mmdet3d_plugin/ops/__init__.py
new file mode 100644
index 0000000..cf23848
--- /dev/null
+++ b/projects/mmdet3d_plugin/ops/__init__.py
@@ -0,0 +1,92 @@
+import torch
+
+from .deformable_aggregation import DeformableAggregationFunction
+
+
def deformable_aggregation_function(
    feature_maps,
    spatial_shape,
    scale_start_index,
    sampling_location,
    weights,
):
    """Functional wrapper over the custom deformable-aggregation autograd op.

    All arguments are forwarded unchanged to
    ``DeformableAggregationFunction.apply``.
    """
    op_inputs = (
        feature_maps,
        spatial_shape,
        scale_start_index,
        sampling_location,
        weights,
    )
    return DeformableAggregationFunction.apply(*op_inputs)
+
+
def feature_maps_format(feature_maps, inverse=False):
    """Convert between per-scale feature maps and the flat "column" layout
    consumed by the deformable aggregation op.

    Forward (inverse=False): list of (bs, num_cams, C, H_i, W_i) tensors ->
    [col_feats, spatial_shape, scale_start_index], where col_feats flattens
    all cameras/scales/pixels into one token dimension. Nested list input is
    formatted per group and concatenated.
    Inverse (inverse=True): unpack the triple back into per-scale maps,
    grouping consecutive cameras that share identical per-level shapes.
    """
    if inverse:
        col_feats, spatial_shape, scale_start_index = feature_maps
        num_cams, num_levels = spatial_shape.shape[:2]

        # Per (cam, level) token counts H*W, used to split the flat column.
        split_size = spatial_shape[..., 0] * spatial_shape[..., 1]
        split_size = split_size.cpu().numpy().tolist()

        idx = 0  # NOTE(review): unused
        # Group consecutive cameras with identical per-level shapes so each
        # group can be unflattened in one go.
        cam_split = [1]
        cam_split_size = [sum(split_size[0])]
        for i in range(num_cams - 1):
            if not torch.all(spatial_shape[i] == spatial_shape[i + 1]):
                # Shape changed: start a new camera group.
                cam_split.append(0)
                cam_split_size.append(0)
            cam_split[-1] += 1
            cam_split_size[-1] += sum(split_size[i + 1])
        # Split the token dim per group and re-expose the camera axis.
        mc_feat = [
            x.unflatten(1, (cam_split[i], -1))
            for i, x in enumerate(col_feats.split(cam_split_size, dim=1))
        ]

        spatial_shape = spatial_shape.cpu().numpy().tolist()
        mc_ms_feat = []
        shape_index = 0
        for i, feat in enumerate(mc_feat):
            # Split each group into its scales, then restore (H, W) and move
            # channels back in front: (bs, cams, HW, C) -> (bs, cams, C, H, W).
            feat = list(feat.split(split_size[shape_index], dim=2))
            for j, f in enumerate(feat):
                feat[j] = f.unflatten(2, spatial_shape[shape_index][j])
                feat[j] = feat[j].permute(0, 1, 4, 2, 3)
            mc_ms_feat.append(feat)
            shape_index += cam_split[i]
        return mc_ms_feat

    # Nested input (e.g. multiple camera groups): format each and concat.
    if isinstance(feature_maps[0], (list, tuple)):
        formated = [feature_maps_format(x) for x in feature_maps]
        col_feats = torch.cat([x[0] for x in formated], dim=1)
        spatial_shape = torch.cat([x[1] for x in formated], dim=0)
        scale_start_index = torch.cat([x[2] for x in formated], dim=0)
        return [col_feats, spatial_shape, scale_start_index]

    bs, num_cams = feature_maps[0].shape[:2]
    spatial_shape = []

    # Flatten each scale's spatial grid: (bs, cams, C, H, W) -> (bs, cams, C, HW).
    col_feats = []
    for i, feat in enumerate(feature_maps):
        spatial_shape.append(feat.shape[-2:])
        col_feats.append(
            torch.reshape(feat, (bs, num_cams, feat.shape[2], -1))
        )

    # Concatenate scales along pixels, channels-last, then fold cameras into
    # the token dim: result is (bs, cams*sum(HW), C).
    col_feats = torch.cat(col_feats, dim=-1).permute(0, 1, 3, 2).flatten(1, 2)
    spatial_shape = [spatial_shape] * num_cams
    spatial_shape = torch.tensor(
        spatial_shape,
        dtype=torch.int64,
        device=col_feats.device,
    )
    # Exclusive prefix sum of per-(cam, scale) token counts: the offset of
    # each scale's first token within the flat column.
    scale_start_index = spatial_shape[..., 0] * spatial_shape[..., 1]
    scale_start_index = scale_start_index.flatten().cumsum(dim=0)
    scale_start_index = torch.cat(
        [torch.tensor([0]).to(scale_start_index), scale_start_index[:-1]]
    )
    scale_start_index = scale_start_index.reshape(num_cams, -1)

    feature_maps = [
        col_feats,
        spatial_shape,
        scale_start_index,
    ]
    return feature_maps
diff --git a/projects/mmdet3d_plugin/ops/deformable_aggregation.py b/projects/mmdet3d_plugin/ops/deformable_aggregation.py
new file mode 100644
index 0000000..fe11b69
--- /dev/null
+++ b/projects/mmdet3d_plugin/ops/deformable_aggregation.py
@@ -0,0 +1,75 @@
+import torch
+from torch.autograd.function import Function, once_differentiable
+
+from . import deformable_aggregation_ext
+
+
class DeformableAggregationFunction(Function):
    """Autograd wrapper around the compiled ``deformable_aggregation_ext``
    CUDA kernels (forward + backward)."""

    @staticmethod
    def forward(
        ctx,
        mc_ms_feat,
        spatial_shape,
        scale_start_index,
        sampling_location,
        weights,
    ):
        # output: [bs, num_pts, num_embeds]
        # The extension expects contiguous fp32 features/locations/weights and
        # int32 shape/index tensors; normalize here so callers may pass fp16.
        mc_ms_feat = mc_ms_feat.contiguous().float()
        spatial_shape = spatial_shape.contiguous().int()
        scale_start_index = scale_start_index.contiguous().int()
        sampling_location = sampling_location.contiguous().float()
        weights = weights.contiguous().float()
        output = deformable_aggregation_ext.deformable_aggregation_forward(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        )
        # Save the normalized inputs; backward re-runs the kernel with them.
        ctx.save_for_backward(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        )
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        (
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
        ) = ctx.saved_tensors
        mc_ms_feat = mc_ms_feat.contiguous().float()
        spatial_shape = spatial_shape.contiguous().int()
        scale_start_index = scale_start_index.contiguous().int()
        sampling_location = sampling_location.contiguous().float()
        weights = weights.contiguous().float()

        # Gradient buffers filled in-place by the extension kernel.
        grad_mc_ms_feat = torch.zeros_like(mc_ms_feat)
        grad_sampling_location = torch.zeros_like(sampling_location)
        grad_weights = torch.zeros_like(weights)
        deformable_aggregation_ext.deformable_aggregation_backward(
            mc_ms_feat,
            spatial_shape,
            scale_start_index,
            sampling_location,
            weights,
            grad_output.contiguous(),
            grad_mc_ms_feat,
            grad_sampling_location,
            grad_weights,
        )
        # No gradients for the integer shape/index inputs.
        return (
            grad_mc_ms_feat,
            None,
            None,
            grad_sampling_location,
            grad_weights,
        )
diff --git a/projects/mmdet3d_plugin/ops/setup.py b/projects/mmdet3d_plugin/ops/setup.py
new file mode 100644
index 0000000..cbade27
--- /dev/null
+++ b/projects/mmdet3d_plugin/ops/setup.py
@@ -0,0 +1,60 @@
+import os
+
+import torch
+from setuptools import setup
+from torch.utils.cpp_extension import (
+ BuildExtension,
+ CppExtension,
+ CUDAExtension,
+)
+
+
def make_cuda_ext(
    name,
    module,
    sources,
    sources_cuda=None,
    extra_args=None,
    extra_include_path=None,
):
    """Create a ``CUDAExtension`` (or CPU-only ``CppExtension``) for setup().

    Args:
        name: Extension name (last component of the dotted module path).
        module: Dotted package path; source paths are resolved relative to it.
        sources: C++ source files, relative to ``module``.
        sources_cuda: Extra .cu sources compiled only when CUDA is available.
        extra_args: Extra compiler args for both cxx and nvcc.
        extra_include_path: Additional include directories.

    Returns:
        A ``setuptools.Extension`` built via torch's helpers.
    """
    # Copy all list arguments: the originals used mutable defaults ([]) and
    # mutated `sources` in place (`sources += sources_cuda`), which leaks
    # state across calls and into the caller's list.
    sources = list(sources)
    sources_cuda = [] if sources_cuda is None else list(sources_cuda)
    extra_args = [] if extra_args is None else list(extra_args)
    extra_include_path = [] if extra_include_path is None else list(extra_include_path)

    define_macros = []
    extra_compile_args = {"cxx": [] + extra_args}

    # FORCE_CUDA=1 allows cross-compiling on a machine without a visible GPU.
    if torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1":
        define_macros += [("WITH_CUDA", None)]
        extension = CUDAExtension
        extra_compile_args["nvcc"] = extra_args + [
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]
        sources += sources_cuda  # safe: local copy, caller's list untouched
    else:
        print("Compiling {} without CUDA".format(name))
        extension = CppExtension

    return extension(
        name="{}.{}".format(module, name),
        sources=[os.path.join(*module.split("."), p) for p in sources],
        include_dirs=extra_include_path,
        define_macros=define_macros,
        extra_compile_args=extra_compile_args,
    )
+
+
+if __name__ == "__main__":
+ setup(
+ name="deformable_aggregation_ext",
+ ext_modules=[
+ make_cuda_ext(
+ "deformable_aggregation_ext",
+ module=".",
+ sources=[
+ f"src/deformable_aggregation.cpp",
+ f"src/deformable_aggregation_cuda.cu",
+ ],
+ ),
+ ],
+ cmdclass={"build_ext": BuildExtension},
+ )
diff --git a/projects/mmdet3d_plugin/ops/src/deformable_aggregation.cpp b/projects/mmdet3d_plugin/ops/src/deformable_aggregation.cpp
new file mode 100644
index 0000000..68356a7
--- /dev/null
+++ b/projects/mmdet3d_plugin/ops/src/deformable_aggregation.cpp
@@ -0,0 +1,138 @@
+#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
+
// Forward launcher, defined in deformable_aggregation_cuda.cu.
// Aggregates bilinear samples of `mc_ms_feat` at `sample_location`,
// weighted by `weights`, into `output` ([bs, num_anchors, num_embeds]).
void deformable_aggregation(
    float* output,
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
);
+
+
+/* feat: bs, num_feat, c */
+/* _spatial_shape: cam, scale, 2 */
+/* _scale_start_index: cam, scale */
+/* _sampling_location: bs, anchor, pts, cam, 2 */
+/* _weights: bs, anchor, pts, cam, scale, group */
+/* output: bs, anchor, c */
+/* kernel: bs, anchor, pts, c */
+
+
+at::Tensor deformable_aggregation_forward(
+ const at::Tensor &_mc_ms_feat,
+ const at::Tensor &_spatial_shape,
+ const at::Tensor &_scale_start_index,
+ const at::Tensor &_sampling_location,
+ const at::Tensor &_weights
+) {
+ at::DeviceGuard guard(_mc_ms_feat.device());
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat));
+ int batch_size = _mc_ms_feat.size(0);
+ int num_feat = _mc_ms_feat.size(1);
+ int num_embeds = _mc_ms_feat.size(2);
+ int num_cams = _spatial_shape.size(0);
+ int num_scale = _spatial_shape.size(1);
+ int num_anchors = _sampling_location.size(1);
+ int num_pts = _sampling_location.size(2);
+ int num_groups = _weights.size(5);
+
+ const float* mc_ms_feat = _mc_ms_feat.data_ptr();
+ const int* spatial_shape = _spatial_shape.data_ptr();
+ const int* scale_start_index = _scale_start_index.data_ptr();
+ const float* sampling_location = _sampling_location.data_ptr();
+ const float* weights = _weights.data_ptr();
+
+ auto output = at::zeros({batch_size, num_anchors, num_embeds}, _mc_ms_feat.options());
+ deformable_aggregation(
+ output.data_ptr(),
+ mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
+ batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
+ );
+ return output;
+}
+
+
// Backward launcher, defined in deformable_aggregation_cuda.cu.
// Accumulates gradients w.r.t. features, sampling locations and weights
// into the pre-zeroed grad_* buffers.
void deformable_aggregation_grad(
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    const float* grad_output,
    float* grad_mc_ms_feat,
    float* grad_sampling_location,
    float* grad_weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
);
+
+
+void deformable_aggregation_backward(
+ const at::Tensor &_mc_ms_feat,
+ const at::Tensor &_spatial_shape,
+ const at::Tensor &_scale_start_index,
+ const at::Tensor &_sampling_location,
+ const at::Tensor &_weights,
+ const at::Tensor &_grad_output,
+ at::Tensor &_grad_mc_ms_feat,
+ at::Tensor &_grad_sampling_location,
+ at::Tensor &_grad_weights
+) {
+ at::DeviceGuard guard(_mc_ms_feat.device());
+ const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat));
+ int batch_size = _mc_ms_feat.size(0);
+ int num_feat = _mc_ms_feat.size(1);
+ int num_embeds = _mc_ms_feat.size(2);
+ int num_cams = _spatial_shape.size(0);
+ int num_scale = _spatial_shape.size(1);
+ int num_anchors = _sampling_location.size(1);
+ int num_pts = _sampling_location.size(2);
+ int num_groups = _weights.size(5);
+
+ const float* mc_ms_feat = _mc_ms_feat.data_ptr();
+ const int* spatial_shape = _spatial_shape.data_ptr();
+ const int* scale_start_index = _scale_start_index.data_ptr();
+ const float* sampling_location = _sampling_location.data_ptr();
+ const float* weights = _weights.data_ptr();
+ const float* grad_output = _grad_output.data_ptr();
+
+ float* grad_mc_ms_feat = _grad_mc_ms_feat.data_ptr();
+ float* grad_sampling_location = _grad_sampling_location.data_ptr();
+ float* grad_weights = _grad_weights.data_ptr();
+
+ deformable_aggregation_grad(
+ mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
+ grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
+ batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
+ );
+}
+
+
// Python bindings: imported as `deformable_aggregation_ext` with the two
// functions above exposed under their C++ names.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "deformable_aggregation_forward",
        &deformable_aggregation_forward,
        "deformable_aggregation_forward"
    );
    m.def(
        "deformable_aggregation_backward",
        &deformable_aggregation_backward,
        "deformable_aggregation_backward"
    );
}
diff --git a/projects/mmdet3d_plugin/ops/src/deformable_aggregation_cuda.cu b/projects/mmdet3d_plugin/ops/src/deformable_aggregation_cuda.cu
new file mode 100644
index 0000000..4f748e5
--- /dev/null
+++ b/projects/mmdet3d_plugin/ops/src/deformable_aggregation_cuda.cu
@@ -0,0 +1,318 @@
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <assert.h>
+
+#include <cuda_runtime_api.h>
+
+#include <thrust/device_vector.h>
+#include <thrust/host_vector.h>
+
+
/* Bilinearly sample one channel from a feature map with channel-last
 * (h, w, num_embeds) layout starting at `bottom_data + base_ptr`, at the
 * fractional pixel location (h_im, w_im). Corners falling outside the map
 * contribute zero. */
__device__ float bilinear_sampling(
    const float *&bottom_data, const int &height, const int &width,
    const int &num_embeds, const float &h_im, const float &w_im,
    const int &base_ptr
) {
    // Four integer corners surrounding the sample point.
    const int h_low = floorf(h_im);
    const int w_low = floorf(w_im);
    const int h_high = h_low + 1;
    const int w_high = w_low + 1;

    // Fractional offsets (lh, lw) and their complements (hh, hw).
    const float lh = h_im - h_low;
    const float lw = w_im - w_low;
    const float hh = 1 - lh, hw = 1 - lw;

    // Element strides for the (h, w, c) layout.
    const int w_stride = num_embeds;
    const int h_stride = width * w_stride;
    const int h_low_ptr_offset = h_low * h_stride;
    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
    const int w_low_ptr_offset = w_low * w_stride;
    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

    // Read each corner only when it lies inside the feature map.
    float v1 = 0;
    if (h_low >= 0 && w_low >= 0) {
        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
        v1 = bottom_data[ptr1];
    }
    float v2 = 0;
    if (h_low >= 0 && w_high <= width - 1) {
        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
        v2 = bottom_data[ptr2];
    }
    float v3 = 0;
    if (h_high <= height - 1 && w_low >= 0) {
        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
        v3 = bottom_data[ptr3];
    }
    float v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1) {
        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
        v4 = bottom_data[ptr4];
    }

    // Standard bilinear interpolation weights.
    const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

    const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    return val;
}
+
+
/* Backward of bilinear_sampling for a single channel/sample.
 * Scatters gradients into grad_mc_ms_feat (four corners),
 * grad_sampling_location and grad_weights with atomicAdd, since many
 * threads may target the same entries. */
__device__ void bilinear_sampling_grad(
    const float *&bottom_data, const float &weight,
    const int &height, const int &width,
    const int &num_embeds, const float &h_im, const float &w_im,
    const int &base_ptr,
    const float &grad_output,
    float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights) {
    // Corner coordinates / fractions, identical to bilinear_sampling.
    const int h_low = floorf(h_im);
    const int w_low = floorf(w_im);
    const int h_high = h_low + 1;
    const int w_high = w_low + 1;

    const float lh = h_im - h_low;
    const float lw = w_im - w_low;
    const float hh = 1 - lh, hw = 1 - lw;

    const int w_stride = num_embeds;
    const int h_stride = width * w_stride;
    const int h_low_ptr_offset = h_low * h_stride;
    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
    const int w_low_ptr_offset = w_low * w_stride;
    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

    const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
    // d(output)/d(corner value) = interp weight * aggregation weight.
    const float top_grad_mc_ms_feat = grad_output * weight;
    // Accumulators for d(sampled value)/d(h_im) and /d(w_im).
    float grad_h_weight = 0, grad_w_weight = 0;

    float v1 = 0;
    if (h_low >= 0 && w_low >= 0) {
        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
        v1 = bottom_data[ptr1];
        grad_h_weight -= hw * v1;
        grad_w_weight -= hh * v1;
        atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
    }
    float v2 = 0;
    if (h_low >= 0 && w_high <= width - 1) {
        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
        v2 = bottom_data[ptr2];
        grad_h_weight -= lw * v2;
        grad_w_weight += hh * v2;
        atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
    }
    float v3 = 0;
    if (h_high <= height - 1 && w_low >= 0) {
        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
        v3 = bottom_data[ptr3];
        grad_h_weight += hw * v3;
        grad_w_weight -= lh * v3;
        atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
    }
    float v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1) {
        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
        v4 = bottom_data[ptr4];
        grad_h_weight += lw * v4;
        grad_w_weight += lh * v4;
        atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
    }

    // Recomputed forward value: d(output)/d(weight) = sampled value.
    const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    atomicAdd(grad_weights, grad_output * val);
    // Chain rule back to the normalized locations: the caller maps
    // loc -> loc * (width|height) - 0.5, hence the width/height factors.
    atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
    atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
}
+
+
/* Forward kernel: one thread per
 * (batch, anchor, point, camera, scale, channel) element. Each thread
 * bilinearly samples one channel and atomically accumulates the weighted
 * value into output[batch, anchor, channel]. */
__global__ void deformable_aggregation_kernel(
    const int num_kernels,
    float* output,
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_kernels) return;

    // Each group of (num_embeds / num_groups) consecutive channels shares
    // one aggregation weight.
    const float weight = *(weights + idx / (num_embeds / num_groups));
    // Decompose the flat index; channel varies fastest.
    const int channel_index = idx % num_embeds;
    idx /= num_embeds;
    const int scale_index = idx % num_scale;
    idx /= num_scale;

    const int cam_index = idx % num_cams;
    idx /= num_cams;
    const int pts_index = idx % num_pts;
    idx /= num_pts;

    int anchor_index = idx % num_anchors;
    idx /= num_anchors;
    const int batch_index = idx % batch_size;
    idx /= batch_size;

    // Flatten (batch, anchor) for output / location addressing.
    anchor_index = batch_index * num_anchors + anchor_index;
    const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;

    // Locations are normalized; skip samples outside (0, 1).
    const float loc_w = sample_location[loc_offset];
    if (loc_w <= 0 || loc_w >= 1) return;
    const float loc_h = sample_location[loc_offset + 1];
    if (loc_h <= 0 || loc_h >= 1) return;

    // Offset of this camera/scale level inside the flattened feature tensor.
    int cam_scale_index = cam_index * num_scale + scale_index;
    const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;

    cam_scale_index = cam_scale_index << 1;
    const int h = spatial_shape[cam_scale_index];
    const int w = spatial_shape[cam_scale_index + 1];

    // Map normalized coords to pixel coordinates (-0.5 centers the pixels).
    const float h_im = loc_h * h - 0.5;
    const float w_im = loc_w * w - 0.5;

    // Many (point, camera, scale) threads add into the same output channel.
    atomicAdd(
        output + anchor_index * num_embeds + channel_index,
        bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight
    );
}
+
+
/* Backward kernel: identical thread decomposition to the forward kernel.
 * Each thread routes the incoming grad_output through
 * bilinear_sampling_grad, which scatters gradients with atomics. */
__global__ void deformable_aggregation_grad_kernel(
    const int num_kernels,
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    const float* grad_output,
    float* grad_mc_ms_feat,
    float* grad_sampling_location,
    float* grad_weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_kernels) return;

    // Flat index of the shared per-group weight; kept so the weight grad
    // can be accumulated at the same location.
    const int weights_ptr = idx / (num_embeds / num_groups);
    const int channel_index = idx % num_embeds;
    idx /= num_embeds;
    const int scale_index = idx % num_scale;
    idx /= num_scale;

    const int cam_index = idx % num_cams;
    idx /= num_cams;
    const int pts_index = idx % num_pts;
    idx /= num_pts;

    int anchor_index = idx % num_anchors;
    idx /= num_anchors;
    const int batch_index = idx % batch_size;
    idx /= batch_size;

    anchor_index = batch_index * num_anchors + anchor_index;
    const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;

    // Out-of-range locations contributed nothing forward, so no gradient.
    const float loc_w = sample_location[loc_offset];
    if (loc_w <= 0 || loc_w >= 1) return;
    const float loc_h = sample_location[loc_offset + 1];
    if (loc_h <= 0 || loc_h >= 1) return;

    const float grad = grad_output[anchor_index*num_embeds + channel_index];

    int cam_scale_index = cam_index * num_scale + scale_index;
    const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;

    cam_scale_index = cam_scale_index << 1;
    const int h = spatial_shape[cam_scale_index];
    const int w = spatial_shape[cam_scale_index + 1];

    // Same normalized -> pixel-coordinate mapping as the forward kernel.
    const float h_im = loc_h * h - 0.5;
    const float w_im = loc_w * w - 0.5;

    const float weight = weights[weights_ptr];
    float *grad_weights_ptr = grad_weights + weights_ptr;
    float *grad_location_ptr = grad_sampling_location + loc_offset;
    bilinear_sampling_grad(
        mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
        value_offset,
        grad,
        grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr
    );
}
+
+
// Host-side launcher for the forward kernel: one thread per
// (batch, anchor, point, camera, scale, channel) tuple.
void deformable_aggregation(
    float* output,
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
    const int num_kernels =
        batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
    const int threads_per_block = 128;
    // Integer ceil-division; equivalent to ceil(num_kernels / 128.0).
    const int num_blocks = (num_kernels + threads_per_block - 1) / threads_per_block;
    deformable_aggregation_kernel<<<num_blocks, threads_per_block>>>(
        num_kernels, output,
        mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
        batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
    );
}
+
+
// Host-side launcher for the backward kernel; same grid layout as forward.
void deformable_aggregation_grad(
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    const float* grad_output,
    float* grad_mc_ms_feat,
    float* grad_sampling_location,
    float* grad_weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
    const int num_kernels =
        batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
    const int threads_per_block = 128;
    // Integer ceil-division; equivalent to ceil(num_kernels / 128.0).
    const int num_blocks = (num_kernels + threads_per_block - 1) / threads_per_block;
    deformable_aggregation_grad_kernel<<<num_blocks, threads_per_block>>>(
        num_kernels,
        mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
        grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
        batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
    );
}
diff --git a/requirement.txt b/requirement.txt
new file mode 100644
index 0000000..e18254d
--- /dev/null
+++ b/requirement.txt
@@ -0,0 +1,15 @@
+numpy==1.23.5
+mmcv_full==1.7.1
+mmdet==2.28.2
+urllib3==1.26.16
+pyquaternion==0.9.9
+nuscenes-devkit==1.1.10
+yapf==0.32
+tensorboard==2.14.0
+motmetrics==1.1.3
+pandas==1.1.5
+flash-attn==2.3.2
+opencv-python==4.8.1.78
+prettytable==3.7.0
+scikit-learn==1.3.0
+wandb==0.16.6
\ No newline at end of file
diff --git a/resources/legend.png b/resources/legend.png
new file mode 100644
index 0000000..c34e3b3
Binary files /dev/null and b/resources/legend.png differ
diff --git a/resources/motion_planner.png b/resources/motion_planner.png
new file mode 100644
index 0000000..a852fb6
Binary files /dev/null and b/resources/motion_planner.png differ
diff --git a/resources/overview.png b/resources/overview.png
new file mode 100644
index 0000000..cf15694
Binary files /dev/null and b/resources/overview.png differ
diff --git a/resources/sdc_car.png b/resources/sdc_car.png
new file mode 100644
index 0000000..df19cbc
Binary files /dev/null and b/resources/sdc_car.png differ
diff --git a/resources/sparse_perception.png b/resources/sparse_perception.png
new file mode 100644
index 0000000..5696c76
Binary files /dev/null and b/resources/sparse_perception.png differ
diff --git a/scripts/create_data.sh b/scripts/create_data.sh
new file mode 100644
index 0000000..63f7e97
--- /dev/null
+++ b/scripts/create_data.sh
@@ -0,0 +1,16 @@
+export PYTHONPATH="$(dirname $0)/..":$PYTHONPATH
+
+python tools/data_converter/nuscenes_converter.py nuscenes \
+ --root-path ./data/nuscenes \
+ --canbus ./data/nuscenes \
+ --out-dir ./data/infos/ \
+ --extra-tag nuscenes \
+ --version v1.0-mini
+
+python tools/data_converter/nuscenes_converter.py nuscenes \
+ --root-path ./data/nuscenes \
+ --canbus ./data/nuscenes \
+ --out-dir ./data/infos/ \
+ --extra-tag nuscenes \
+ --version v1.0
+
diff --git a/scripts/kmeans.sh b/scripts/kmeans.sh
new file mode 100644
index 0000000..6b01371
--- /dev/null
+++ b/scripts/kmeans.sh
@@ -0,0 +1,4 @@
+python tools/kmeans/kmeans_det.py
+python tools/kmeans/kmeans_map.py
+python tools/kmeans/kmeans_motion.py
+python tools/kmeans/kmeans_plan.py
\ No newline at end of file
diff --git a/scripts/test.sh b/scripts/test.sh
new file mode 100644
index 0000000..c6861d2
--- /dev/null
+++ b/scripts/test.sh
@@ -0,0 +1,7 @@
+bash ./tools/dist_test.sh \
+ projects/configs/sparsedrive_small_stage2.py \
+ ckpt/sparsedrive_stage2.pth \
+ 8 \
+ --deterministic \
+ --eval bbox
+ # --result_file ./work_dirs/sparsedrive_small_stage2/results.pkl
\ No newline at end of file
diff --git a/scripts/train.sh b/scripts/train.sh
new file mode 100644
index 0000000..c1f6171
--- /dev/null
+++ b/scripts/train.sh
@@ -0,0 +1,11 @@
+## stage1
+bash ./tools/dist_train.sh \
+ projects/configs/sparsedrive_small_stage1.py \
+ 8 \
+ --deterministic
+
+## stage2
+bash ./tools/dist_train.sh \
+ projects/configs/sparsedrive_small_stage2.py \
+ 8 \
+ --deterministic
\ No newline at end of file
diff --git a/scripts/visualize.sh b/scripts/visualize.sh
new file mode 100644
index 0000000..dbf439c
--- /dev/null
+++ b/scripts/visualize.sh
@@ -0,0 +1,4 @@
+export PYTHONPATH="$(dirname $0)/..":$PYTHONPATH
+python tools/visualization/visualize.py \
+ projects/configs/sparsedrive_small_stage2.py \
+ --result-path work_dirs/sparsedrive_small_stage2/results.pkl
\ No newline at end of file
diff --git a/tools/analyze_occ_mask.py b/tools/analyze_occ_mask.py
new file mode 100644
index 0000000..5efb399
--- /dev/null
+++ b/tools/analyze_occ_mask.py
@@ -0,0 +1,184 @@
+#!/usr/bin/env python3
+"""
+Compute statistics for NuScenes objects that don't meet the num_lidar_pts > 0
+mask criteria (i.e. num_lidar_pts <= 0) used in NuScenes3DDataset.get_ann_info().
+
+Usage:
+ python tools/analyze_num_lidar_pts_mask.py [ann_file]
+ python tools/analyze_num_lidar_pts_mask.py data/infos/nuscenes_infos_train.pkl
+ python tools/analyze_num_lidar_pts_mask.py data/infos/nuscenes_infos_val.pkl
+"""
+
+import argparse
+from collections import defaultdict
+
+import mmcv
+import numpy as np
+
+
def main():
    """Report statistics about annotations removed by ``num_lidar_pts > 0``.

    Loads a NuScenes info pkl, mirrors the dataset mask used in
    NuScenes3DDataset.get_ann_info() (``num_lidar_pts > 0``) and prints
    object-level and track-level statistics about what gets filtered out.
    """
    parser = argparse.ArgumentParser(
        description="Stats for objects with num_lidar_pts <= 0 (filtered out by dataset mask)"
    )
    parser.add_argument(
        "ann_file",
        nargs="?",
        default="data/infos/nuscenes_infos_train.pkl",
        help="Path to NuScenes info pkl (default: data/infos/nuscenes_infos_train.pkl)",
    )
    parser.add_argument(
        "--load-interval",
        type=int,
        default=1,
        help="Same as dataset load_interval (default: 1)",
    )
    parser.add_argument(
        "--by-class",
        action="store_true",
        help="Print per-class breakdown of filtered objects",
    )
    args = parser.parse_args()

    print(f"Loading {args.ann_file} ...")
    data = mmcv.load(args.ann_file, file_format="pkl")
    # Same ordering/subsampling as the dataset itself.
    infos = list(sorted(data["infos"], key=lambda e: e["timestamp"]))
    infos = infos[:: args.load_interval]
    print(f"Loaded {len(infos)} samples (load_interval={args.load_interval})\n")

    total_objs = 0
    filtered_out = 0  # num_lidar_pts <= 0
    samples_with_any_filtered = 0
    samples_with_all_filtered = 0
    # num_lidar_pts value distribution for filtered objects (0, -1, etc.).
    # NOTE: collected for inspection in a debugger; not printed below.
    filtered_pts_dist = defaultdict(int)
    # per-class counters: total objects and filtered objects
    class_total = defaultdict(int)
    class_filtered = defaultdict(int)

    for info in infos:
        num_pts = np.asarray(info["num_lidar_pts"])
        gt_names = info["gt_names"]
        n = len(num_pts)
        if n == 0:
            continue

        total_objs += n
        mask_kept = num_pts > 0
        mask_filtered = ~mask_kept
        n_filtered = mask_filtered.sum()
        filtered_out += n_filtered

        if n_filtered > 0:
            samples_with_any_filtered += 1
            if n_filtered == n:
                samples_with_all_filtered += 1

        for v in num_pts[mask_filtered]:
            filtered_pts_dist[int(v)] += 1

        for i in range(n):
            # gt_names entries may be numpy scalars; normalize to str.
            name = gt_names[i] if isinstance(gt_names[i], str) else gt_names[i].item()
            class_total[name] += 1
            if num_pts[i] <= 0:
                class_filtered[name] += 1

    # Summary
    print("=" * 60)
    print("Summary (mask: num_lidar_pts > 0)")
    print("=" * 60)
    print(f"  Total objects: {total_objs}")
    print(f"  Filtered (num_lidar_pts <= 0): {filtered_out}")
    # Guard against an empty or object-free split (avoid ZeroDivisionError).
    pct = 100.0 * filtered_out / total_objs if total_objs else 0.0
    print(f"  Filtered %: {pct:.2f}%")

    if args.by_class and class_total:
        print("\n" + "=" * 60)
        print("Per-class (filtered = num_lidar_pts <= 0)")
        print("=" * 60)
        for name in sorted(class_total.keys()):
            tot = class_total[name]
            flt = class_filtered[name]
            pct = 100.0 * flt / tot if tot else 0
            print(f"  {name:25s} total: {tot:6d} filtered: {flt:6d} ({pct:5.2f}%)")

    # -------------------------------------------------------------------------
    # Track-level stats: per track, how many annotations are not present at all
    # in the infos (instance not labelled in that sample) after the first
    # sample where the track appears.
    # -------------------------------------------------------------------------
    if infos and "instance_inds" in infos[0]:
        # scene_token -> sorted list of sample_idx belonging to that scene
        scene_to_sample_indices = defaultdict(list)
        # (scene_token, instance_ind) -> list of sample_idx where instance is annotated
        tracks = defaultdict(lambda: defaultdict(list))

        for sample_idx, info in enumerate(infos):
            scene_token = info["scene_token"]
            scene_to_sample_indices[scene_token].append(sample_idx)
            instance_inds = info.get("instance_inds", [])
            if len(instance_inds) == 0:
                continue
            instance_inds = np.asarray(instance_inds)
            for i in range(len(instance_inds)):
                instance_ind = int(instance_inds[i])
                tracks[scene_token][instance_ind].append(sample_idx)

        for scene_token in scene_to_sample_indices:
            scene_to_sample_indices[scene_token].sort()

        # For each track: span = [first_sample, last_sample]. Missing = number of
        # samples in that span (in this scene) where the instance is not in the infos.
        total_tracks = 0
        total_missing_not_labelled = 0
        total_missing_after_track_end = 0
        tracks_with_any_missing = 0
        missing_per_track_dist = defaultdict(int)
        after_end_per_track_dist = defaultdict(int)

        for scene_token, instances in tracks.items():
            scene_samples = scene_to_sample_indices[scene_token]
            for instance_ind, appearances in instances.items():
                if len(appearances) == 0:
                    continue
                appearances = sorted(appearances)
                first_sample = appearances[0]
                last_sample = appearances[-1]
                # Samples in this scene that fall inside the track's span.
                span_samples = [s for s in scene_samples if first_sample <= s <= last_sample]
                expected_frames = len(span_samples)
                present_frames = len(appearances)
                missing = expected_frames - present_frames

                # Samples in this scene after the last labelled sample for this track.
                after_end = sum(1 for s in scene_samples if s > last_sample)
                total_missing_after_track_end += after_end
                after_end_per_track_dist[after_end] += 1

                total_tracks += 1
                total_missing_not_labelled += missing
                if missing > 0:
                    tracks_with_any_missing += 1
                missing_per_track_dist[missing] += 1

        print("\n" + "=" * 60)
        print("Track stats (missing = not present/labelled in infos at all)")
        print("=" * 60)
        print(f"  Total tracks: {total_tracks}")
        print(f"  Total annotations missing mid track span (not in infos): {total_missing_not_labelled}")
        print(f"  Total annotations missing after track end (not in infos): {total_missing_after_track_end}")
        print("\n  Distribution of # missing per track mid span (not in infos):")
        for k in sorted(missing_per_track_dist.keys()):
            n_tracks = missing_per_track_dist[k]
            print(f"    {k:3d} missing: {n_tracks:6d} tracks")
        print("\n  Distribution of # missing after track end per track:")
        for k in sorted(after_end_per_track_dist.keys()):
            n_tracks = after_end_per_track_dist[k]
            print(f"    {k:3d} after end: {n_tracks:6d} tracks")
    else:
        print("\n  (Skipping track stats: no 'instance_inds' in infos)")

    print()


if __name__ == "__main__":
    main()
diff --git a/tools/benchmark.py b/tools/benchmark.py
new file mode 100644
index 0000000..e8b3511
--- /dev/null
+++ b/tools/benchmark.py
@@ -0,0 +1,197 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import time
+import torch
+from mmcv import Config
+from mmcv.parallel import MMDataParallel
+from mmcv.runner import load_checkpoint, wrap_fp16_model
+import sys
+sys.path.append('.')
+from projects.mmdet3d_plugin.datasets.builder import build_dataloader
+from projects.mmdet3d_plugin.datasets import custom_build_dataset
+from mmdet.models import build_detector
+from mmcv.cnn.utils.flops_counter import add_flops_counting_methods
+from mmcv.parallel import scatter
+
+
def parse_args():
    """Parse command-line arguments for the benchmark script."""
    parser = argparse.ArgumentParser(description='MMDet benchmark a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', default=None, help='checkpoint file')
    # type=int: without it, values supplied on the command line stay strings,
    # silently breaking `(i + 1) == args.samples` and the modulo log check.
    parser.add_argument(
        '--samples', type=int, default=1000, help='samples to benchmark')
    parser.add_argument(
        '--log-interval', type=int, default=50, help='interval of logging')
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase'
        'the inference speed')
    args = parser.parse_args()
    return args
+
+
def get_max_memory(model):
    """Return the peak GPU memory allocated so far, in whole MB.

    ``model`` is expected to expose ``output_device`` (as MMDataParallel
    does); when absent the current CUDA device is queried.
    """
    device = getattr(model, 'output_device', None)
    mem = torch.cuda.max_memory_allocated(device=device)
    # Routing through an int tensor truncates the MB value to an integer.
    mem_mb = torch.tensor([mem / (1024 * 1024)],
                          dtype=torch.int,
                          device=device)
    return mem_mb.item()
+
+
def main():
    """Entry point: print FLOPs/params, then benchmark memory and FPS."""
    args = parse_args()
    get_flops_params(args)
    get_mem_fps(args)
+
def get_mem_fps(args):
    """Benchmark inference FPS and peak GPU memory on the test split.

    Builds the test dataloader and model from ``args.config``, optionally
    loads ``args.checkpoint``, then times single-sample inference for up to
    ``args.samples`` iterations (the first ``num_warmup`` are skipped).
    """
    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    print(cfg.data.test)
    dataset = custom_build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False)

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    if args.checkpoint is not None:
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    # if args.fuse_conv_bn:
    #     model = fuse_module(model)

    model = MMDataParallel(model, device_ids=[0])

    model.eval()

    # the first several iterations may be very slow so skip them
    num_warmup = 5
    pure_inf_time = 0

    # benchmark with several samples and take the average
    max_memory = 0
    for i, data in enumerate(data_loader):
        # torch.cuda.synchronize()
        with torch.no_grad():
            start_time = time.perf_counter()
            model(return_loss=False, rescale=True, **data)

        # Synchronize so the timed span covers the full GPU work.
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time
        max_memory = max(max_memory, get_max_memory(model))

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % args.log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(f'Done image [{i + 1:<3}/ {args.samples}], '
                      f'fps: {fps:.1f} img / s, '
                      f"gpu mem: {max_memory} M")

        if (i + 1) == args.samples:
            # NOTE(review): on the final sample `elapsed` is added a second
            # time (also in the upstream mmdet benchmark), slightly
            # underestimating the overall FPS — confirm if exactness matters.
            pure_inf_time += elapsed
            fps = (i + 1 - num_warmup) / pure_inf_time
            print(f'Overall fps: {fps:.1f} img / s')
            break
+
+
def get_flops_params(args):
    """Print FLOPs and parameter counts for the model and its submodules.

    The deformable aggregation op is a custom CUDA extension that mmcv's
    flops counter cannot see, so its cost is estimated analytically from the
    config and added to the 'head' and 'total' numbers.
    """
    gpu_id = 0
    cfg = Config.fromfile(args.config)
    dataset = custom_build_dataset(cfg.data.val)
    dataloader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=0,
        dist=False,
        shuffle=False,
    )
    # A single real sample, moved to the GPU, drives the flops counting.
    data_iter = dataloader.__iter__()
    data = next(data_iter)
    data = scatter(data, [gpu_id])[0]

    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
    fp16_cfg = cfg.get('fp16', None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    if args.checkpoint is not None:
        load_checkpoint(model, args.checkpoint, map_location='cpu')
    model = model.cuda(gpu_id)
    model.eval()

    # Approximate FLOPs for one bilinear sample inside the custom op.
    bilinear_flops = 11
    # Detection head: fixed-scale plus learnable keypoints per anchor.
    num_key_pts_det = (
        cfg.model["head"]['det_head']["deformable_model"]["kps_generator"]["num_learnable_pts"]
        + len(cfg.model["head"]['det_head']["deformable_model"]["kps_generator"]["fix_scale"])
    )
    deformable_agg_flops_det = (
        cfg.num_decoder
        * cfg.embed_dims
        * cfg.num_levels
        * cfg.model["head"]['det_head']["instance_bank"]["num_anchor"]
        * cfg.model["head"]['det_head']["deformable_model"]["num_cams"]
        * num_key_pts_det
        * bilinear_flops
    )
    # Map head: keypoints are counted per sampled polyline point.
    num_key_pts_map = (
        cfg.model["head"]['map_head']["deformable_model"]["kps_generator"]["num_learnable_pts"]
        + len(cfg.model["head"]['map_head']["deformable_model"]["kps_generator"]["fix_height"])
    ) * cfg.model["head"]['map_head']["deformable_model"]["kps_generator"]["num_sample"]
    deformable_agg_flops_map = (
        cfg.num_decoder
        * cfg.embed_dims
        * cfg.num_levels
        * cfg.model["head"]['map_head']["instance_bank"]["num_anchor"]
        * cfg.model["head"]['map_head']["deformable_model"]["num_cams"]
        * num_key_pts_map
        * bilinear_flops
    )
    deformable_agg_flops = deformable_agg_flops_det + deformable_agg_flops_map

    # "total" runs first so total_flops/total_params exist for the later
    # percentage columns.
    for module in ["total", "img_backbone", "img_neck", "head"]:
        if module != "total":
            flops_model = add_flops_counting_methods(getattr(model, module))
        else:
            flops_model = add_flops_counting_methods(model)
        flops_model.eval()
        flops_model.start_flops_count()

        # Feed each submodule the intermediate input it would receive in a
        # full forward pass.
        if module == "img_backbone":
            flops_model(data["img"].flatten(0, 1))
        elif module == "img_neck":
            flops_model(model.img_backbone(data["img"].flatten(0, 1)))
        elif module == "head":
            flops_model(model.extract_feat(data["img"], metas=data), data)
        else:
            flops_model(**data)
        flops_count, params_count = flops_model.compute_average_flops_cost()
        # Undo the per-batch averaging done by compute_average_flops_cost().
        flops_count *= flops_model.__batch_counter__
        flops_model.stop_flops_count()
        if module == "head" or module == "total":
            flops_count += deformable_agg_flops
        if module == "total":
            total_flops = flops_count
            total_params = params_count
        print(
            f"{module:<13} complexity: "
            f"FLOPs={flops_count/ 10.**9:>8.4f} G / {flops_count/total_flops*100:>6.2f}%, "
            f"Params={params_count/10**6:>8.4f} M / {params_count/total_params*100:>6.2f}%."
        )

if __name__ == '__main__':
    main()
diff --git a/tools/data_converter/__init__.py b/tools/data_converter/__init__.py
new file mode 100755
index 0000000..ef101fe
--- /dev/null
+++ b/tools/data_converter/__init__.py
@@ -0,0 +1 @@
+# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/tools/data_converter/nuscenes_converter.py b/tools/data_converter/nuscenes_converter.py
new file mode 100755
index 0000000..f7d346d
--- /dev/null
+++ b/tools/data_converter/nuscenes_converter.py
@@ -0,0 +1,602 @@
+import os
+import math
+import copy
+import argparse
+from os import path as osp
+from collections import OrderedDict
+from typing import List, Tuple, Union
+
+import numpy as np
+from pyquaternion import Quaternion
+from shapely.geometry import MultiPoint, box
+
+import mmcv
+
+from nuscenes.nuscenes import NuScenes
+from nuscenes.can_bus.can_bus_api import NuScenesCanBus
+from nuscenes.utils.geometry_utils import transform_matrix
+from nuscenes.utils.data_classes import Box
+from nuscenes.utils.geometry_utils import view_points
+from nuscenes.prediction import PredictHelper, convert_local_coords_to_global
+
+from projects.mmdet3d_plugin.datasets.map_utils.nuscmap_extractor import NuscMapExtractor
+
# Mapping from raw nuScenes category names to the 10 coarse detection
# class names used downstream; several fine-grained categories collapse
# into one class (e.g. all pedestrian subtypes -> 'pedestrian', both bus
# subtypes -> 'bus'). Categories not listed here keep their raw name.
NameMapping = {
    "movable_object.barrier": "barrier",
    "vehicle.bicycle": "bicycle",
    "vehicle.bus.bendy": "bus",
    "vehicle.bus.rigid": "bus",
    "vehicle.car": "car",
    "vehicle.construction": "construction_vehicle",
    "vehicle.motorcycle": "motorcycle",
    "human.pedestrian.adult": "pedestrian",
    "human.pedestrian.child": "pedestrian",
    "human.pedestrian.construction_worker": "pedestrian",
    "human.pedestrian.police_officer": "pedestrian",
    "movable_object.trafficcone": "traffic_cone",
    "vehicle.trailer": "trailer",
    "vehicle.truck": "truck",
}
+
def quart_to_rpy(qua):
    """Convert a quaternion to (roll, pitch, yaw) Euler angles.

    Args:
        qua (sequence of 4 floats): Quaternion in (x, y, z, w) order.

    Returns:
        tuple[float, float, float]: (roll, pitch, yaw) in radians.
    """
    x, y, z, w = qua
    roll = math.atan2(2 * (w * x + y * z), 1 - 2 * (x * x + y * y))
    # Clamp the asin argument: floating-point error can push the value of
    # a unit quaternion slightly outside [-1, 1] near gimbal lock (e.g.
    # pitch == pi/2), where math.asin would raise ValueError.
    pitch = math.asin(max(-1.0, min(1.0, 2 * (w * y - x * z))))
    yaw = math.atan2(2 * (w * z + x * y), 1 - 2 * (z * z + y * y))
    return roll, pitch, yaw
+
def locate_message(utimes, utime):
    """Return the index of the message in *utimes* closest to *utime*.

    *utimes* must be sorted in ascending order. When *utime* is exactly
    halfway between two messages, the later one is chosen.
    """
    idx = np.searchsorted(utimes, utime)
    if idx == len(utimes):
        # Query lies past the last message; the last one is closest.
        return idx - 1
    if idx > 0 and (utime - utimes[idx - 1]) < (utimes[idx] - utime):
        # The preceding message is strictly closer.
        return idx - 1
    return idx
+
def geom2anno(map_geoms):
    """Turn shapely map geometries into per-class polyline vertex arrays.

    Args:
        map_geoms (dict): Mapping from class name to a list of shapely
            geometries exposing a ``.coords`` sequence.

    Returns:
        dict: Integer class label (index into the fixed tuple
        ped_crossing/divider/boundary) -> list of vertex arrays. Classes
        outside that tuple are dropped.
    """
    class_names = (
        'ped_crossing',
        'divider',
        'boundary',
    )
    return {
        class_names.index(name): [np.array(geom.coords) for geom in geoms]
        for name, geoms in map_geoms.items()
        if name in class_names
    }
+
def create_nuscenes_infos(root_path,
                          out_path,
                          can_bus_root_path,
                          info_prefix,
                          version='v1.0-trainval',
                          max_sweeps=10,
                          roi_size=(30, 60),):
    """Create info file of nuscene dataset.

    Given the raw data, generate its related info file in pkl format.

    Args:
        root_path (str): Path of the data root.
        out_path (str): Directory where the generated pkl files are written
            (a 'mini' subdirectory is appended for v1.0-mini).
        can_bus_root_path (str): Root path of the nuScenes CAN bus data.
        info_prefix (str): Prefix of the info file to be generated.
        version (str): Version of the data.
            Default: 'v1.0-trainval'
        max_sweeps (int): Max number of sweeps.
            Default: 10
        roi_size (tuple): Size in metres of the map region of interest
            extracted around each sample pose.
    """
    print(version, root_path)
    nusc = NuScenes(version=version, dataroot=root_path, verbose=True)
    nusc_map_extractor = NuscMapExtractor(root_path, roi_size)
    nusc_can_bus = NuScenesCanBus(dataroot=can_bus_root_path)
    # Official scene splits shipped with the nuScenes devkit.
    from nuscenes.utils import splits
    available_vers = ['v1.0-trainval', 'v1.0-test', 'v1.0-mini']
    assert version in available_vers
    if version == 'v1.0-trainval':
        train_scenes = splits.train
        val_scenes = splits.val
    elif version == 'v1.0-test':
        # Test split: all scenes go into the "train" slot; no val scenes.
        train_scenes = splits.test
        val_scenes = []
    elif version == 'v1.0-mini':
        train_scenes = splits.mini_train
        val_scenes = splits.mini_val
        out_path = osp.join(out_path, 'mini')
    else:
        raise ValueError('unknown')
    os.makedirs(out_path, exist_ok=True)

    # filter existing scenes.
    available_scenes = get_available_scenes(nusc)
    available_scene_names = [s['name'] for s in available_scenes]
    train_scenes = list(
        filter(lambda x: x in available_scene_names, train_scenes))
    val_scenes = list(filter(lambda x: x in available_scene_names, val_scenes))
    # Convert scene names to token sets for fast membership checks in
    # _fill_trainval_infos.
    train_scenes = set([
        available_scenes[available_scene_names.index(s)]['token']
        for s in train_scenes
    ])
    val_scenes = set([
        available_scenes[available_scene_names.index(s)]['token']
        for s in val_scenes
    ])

    test = 'test' in version
    if test:
        print('test scene: {}'.format(len(train_scenes)))
    else:
        print('train scene: {}, val scene: {}'.format(
            len(train_scenes), len(val_scenes)))

    train_nusc_infos, val_nusc_infos = _fill_trainval_infos(
        nusc, nusc_map_extractor, nusc_can_bus, train_scenes, val_scenes, test, max_sweeps=max_sweeps)

    metadata = dict(version=version)
    if test:
        print('test sample: {}'.format(len(train_nusc_infos)))
        data = dict(infos=train_nusc_infos, metadata=metadata)
        info_path = osp.join(out_path,
                             '{}_infos_test.pkl'.format(info_prefix))
        mmcv.dump(data, info_path)
    else:
        print('train sample: {}, val sample: {}'.format(
            len(train_nusc_infos), len(val_nusc_infos)))
        data = dict(infos=train_nusc_infos, metadata=metadata)
        info_path = osp.join(out_path,
                             '{}_infos_train.pkl'.format(info_prefix))
        mmcv.dump(data, info_path)
        # Reuse the same dict shell with the val infos for the second dump.
        data['infos'] = val_nusc_infos
        info_val_path = osp.join(out_path,
                                 '{}_infos_val.pkl'.format(info_prefix))
        mmcv.dump(data, info_val_path)
+
def get_available_scenes(nusc):
    """Get available scenes from the input nuscenes class.

    Given the raw data, get the information of available scenes for
    further info generation. A scene is kept when the lidar file of its
    first sample exists on disk.

    Args:
        nusc (class): Dataset class in the nuScenes dataset.

    Returns:
        available_scenes (list[dict]): List of basic information for the
            available scenes.
    """
    available_scenes = []
    print('total scene num: {}'.format(len(nusc.scene)))
    for scene in nusc.scene:
        scene_token = scene['token']
        scene_rec = nusc.get('scene', scene_token)
        sample_rec = nusc.get('sample', scene_rec['first_sample_token'])
        sd_rec = nusc.get('sample_data', sample_rec['data']['LIDAR_TOP'])
        has_more_frames = True
        scene_not_exist = False
        # NOTE: both branches of the loop body break, so only the first
        # lidar frame of each scene is checked; the while structure mirrors
        # the upstream mmdet3d converter this code derives from.
        while has_more_frames:
            lidar_path, boxes, _ = nusc.get_sample_data(sd_rec['token'])
            lidar_path = str(lidar_path)
            if os.getcwd() in lidar_path:
                # path from lyftdataset is absolute path
                lidar_path = lidar_path.split(f'{os.getcwd()}/')[-1]
                # relative path
            if not mmcv.is_filepath(lidar_path):
                scene_not_exist = True
                break
            else:
                break
        if scene_not_exist:
            continue
        available_scenes.append(scene)
    print('exist scene num: {}'.format(len(available_scenes)))
    return available_scenes
+
def _fill_trainval_infos(nusc,
                         nusc_map_extractor,
                         nusc_can_bus,
                         train_scenes,
                         val_scenes,
                         test=False,
                         max_sweeps=10,
                         fut_ts=12,
                         ego_fut_ts=6):
    """Generate the train/val infos from the raw data.

    Args:
        nusc (:obj:`NuScenes`): Dataset class in the nuScenes dataset.
        nusc_map_extractor: Extractor used to pull vectorized map
            geometries around each sample pose.
        nusc_can_bus: CAN bus API used to read the ego-vehicle status.
        train_scenes (list[str]): Basic information of training scenes.
        val_scenes (list[str]): Basic information of validation scenes.
        test (bool): Whether use the test mode. In the test mode, no
            annotations can be accessed. Default: False.
        max_sweeps (int): Max number of sweeps. Default: 10.
        fut_ts (int): Number of future agent-trajectory steps stored per
            box. Default: 12.
        ego_fut_ts (int): Number of future ego-trajectory steps stored per
            sample. Default: 6.

    Returns:
        tuple[list[dict]]: Information of training set and validation set
        that will be saved to the info file.
    """
    train_nusc_infos = []
    val_nusc_infos = []
    cat2idx = {}
    for idx, dic in enumerate(nusc.category):
        cat2idx[dic['name']] = idx

    predict_helper = PredictHelper(nusc)
    for sample in mmcv.track_iter_progress(nusc.sample):
        map_location = nusc.get('log', nusc.get('scene', sample['scene_token'])['log_token'])['location']
        lidar_token = sample['data']['LIDAR_TOP']
        sd_rec = nusc.get('sample_data', lidar_token)
        cs_record = nusc.get('calibrated_sensor',
                             sd_rec['calibrated_sensor_token'])
        pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token'])
        lidar_path, boxes, _ = nusc.get_sample_data(lidar_token)
        mmcv.check_file_exist(lidar_path)

        info = {
            'lidar_path': lidar_path,
            'token': sample['token'],
            'sweeps': [],
            'cams': dict(),
            'scene_token': sample['scene_token'],
            'lidar2ego_translation': cs_record['translation'],
            'lidar2ego_rotation': cs_record['rotation'],
            'ego2global_translation': pose_record['translation'],
            'ego2global_rotation': pose_record['rotation'],
            'timestamp': sample['timestamp'],
            'map_location': map_location,
        }

        # Shorthands for the lidar->ego and ego->global transforms; the
        # rotation matrices are reused below for cameras, sweeps and
        # velocity conversion.
        l2e_r = info['lidar2ego_rotation']
        l2e_t = info['lidar2ego_translation']
        e2g_r = info['ego2global_rotation']
        e2g_t = info['ego2global_translation']
        l2e_r_mat = Quaternion(l2e_r).rotation_matrix
        e2g_r_mat = Quaternion(e2g_r).rotation_matrix

        # extract map annos
        lidar2ego = np.eye(4)
        lidar2ego[:3, :3] = Quaternion(
            info["lidar2ego_rotation"]
        ).rotation_matrix
        lidar2ego[:3, 3] = np.array(info["lidar2ego_translation"])
        ego2global = np.eye(4)
        ego2global[:3, :3] = Quaternion(
            info["ego2global_rotation"]
        ).rotation_matrix
        ego2global[:3, 3] = np.array(info["ego2global_translation"])
        lidar2global = ego2global @ lidar2ego

        # Query map geometries at the global lidar pose, then convert them
        # to per-class polyline arrays.
        translation = list(lidar2global[:3, 3])
        rotation = list(Quaternion(matrix=lidar2global).q)
        map_geoms = nusc_map_extractor.get_map_geom(map_location, translation, rotation)
        map_annos = geom2anno(map_geoms)
        info['map_annos'] = map_annos

        # obtain 6 image's information per frame
        camera_types = [
            'CAM_FRONT',
            'CAM_FRONT_RIGHT',
            'CAM_FRONT_LEFT',
            'CAM_BACK',
            'CAM_BACK_LEFT',
            'CAM_BACK_RIGHT',
        ]
        for cam in camera_types:
            cam_token = sample['data'][cam]
            cam_path, _, cam_intrinsic = nusc.get_sample_data(cam_token)
            cam_info = obtain_sensor2top(nusc, cam_token, l2e_t, l2e_r_mat,
                                         e2g_t, e2g_r_mat, cam)
            cam_info.update(cam_intrinsic=cam_intrinsic)
            info['cams'].update({cam: cam_info})

        # obtain sweeps for a single key-frame
        sd_rec = nusc.get('sample_data', sample['data']['LIDAR_TOP'])
        sweeps = []
        while len(sweeps) < max_sweeps:
            if not sd_rec['prev'] == '':
                sweep = obtain_sensor2top(nusc, sd_rec['prev'], l2e_t,
                                          l2e_r_mat, e2g_t, e2g_r_mat, 'lidar')
                sweeps.append(sweep)
                sd_rec = nusc.get('sample_data', sd_rec['prev'])
            else:
                break
        info['sweeps'] = sweeps
        # obtain annotation
        if not test:
            # object detection annos: boxes (locs, dims, yaw, velocity), names and valid flags
            annotations = [
                nusc.get('sample_annotation', token)
                for token in sample['anns']
            ]
            locs = np.array([b.center for b in boxes]).reshape(-1, 3)
            dims = np.array([b.wlh for b in boxes]).reshape(-1, 3)
            rots = np.array([b.orientation.yaw_pitch_roll[0]
                             for b in boxes]).reshape(-1, 1)
            velocity = np.array(
                [nusc.box_velocity(token)[:2] for token in sample['anns']])
            # convert velo from global to lidar
            for i in range(len(boxes)):
                velo = np.array([*velocity[i], 0.0])
                velo = velo @ np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(
                    l2e_r_mat).T
                velocity[i] = velo[:2]
            names = [b.name for b in boxes]
            for i in range(len(names)):
                if names[i] in NameMapping:
                    names[i] = NameMapping[names[i]]
            names = np.array(names)
            # A box is valid when at least one lidar or radar point hits it.
            valid_flag = np.array(
                [(anno['num_lidar_pts'] + anno['num_radar_pts']) > 0
                 for anno in annotations],
                dtype=bool).reshape(-1) ## TODO update valid flag for tracking
            # we need to convert box size to
            # the format of our lidar coordinate system
            # which is x_size, y_size, z_size (corresponding to l, w, h)
            gt_boxes = np.concatenate([locs, dims[:, [1, 0, 2]], rots], axis=1)
            assert len(gt_boxes) == len(
                annotations), f'{len(gt_boxes)}, {len(annotations)}'

            # object tracking annos: instance_ids
            instance_inds = [nusc.getind('instance', anno['instance_token'])
                             for anno in annotations]

            # motion prediction annos: future trajectories offset in lidar frame and valid mask
            num_box = len(boxes)
            gt_fut_trajs = np.zeros((num_box, fut_ts, 2))
            gt_fut_masks = np.zeros((num_box, fut_ts))
            for i, anno in enumerate(annotations):
                instance_token = anno['instance_token']
                fut_traj_local = predict_helper.get_future_for_agent(
                    instance_token,
                    sample['token'],
                    seconds=fut_ts/2,
                    in_agent_frame=True
                )
                if fut_traj_local.shape[0] > 0:
                    box = boxes[i]
                    trans = box.center
                    rot = Quaternion(matrix=box.rotation_matrix)
                    fut_traj_scene = convert_local_coords_to_global(fut_traj_local, trans, rot)
                    valid_step = fut_traj_scene.shape[0]
                    # Stored as per-step offsets: step 0 relative to the box
                    # center, later steps relative to the previous step.
                    gt_fut_trajs[i, 0] = fut_traj_scene[0] - box.center[:2]
                    gt_fut_trajs[i, 1:valid_step] = fut_traj_scene[1:] - fut_traj_scene[:-1]
                    gt_fut_masks[i, :valid_step] = 1

            # motion planning annos: future trajectories offset in lidar frame and valid mask
            ego_fut_trajs = np.zeros((ego_fut_ts + 1, 3))
            ego_fut_masks = np.zeros((ego_fut_ts + 1))
            sample_cur = sample
            ego_status = get_ego_status(nusc, nusc_can_bus, sample_cur)
            for i in range(ego_fut_ts + 1):
                pose_mat = get_global_sensor_pose(sample_cur, nusc)
                ego_fut_trajs[i] = pose_mat[:3, 3]
                ego_fut_masks[i] = 1
                if sample_cur['next'] == '':
                    # Scene ends early: repeat the last observed position for
                    # the remaining steps (their masks stay 0).
                    ego_fut_trajs[i+1:] = ego_fut_trajs[i]
                    break
                else:
                    sample_cur = nusc.get('sample', sample_cur['next'])
            # global to ego
            ego_fut_trajs = ego_fut_trajs - np.array(pose_record['translation'])
            rot_mat = Quaternion(pose_record['rotation']).inverse.rotation_matrix
            ego_fut_trajs = np.dot(rot_mat, ego_fut_trajs.T).T
            # ego to lidar
            ego_fut_trajs = ego_fut_trajs - np.array(cs_record['translation'])
            rot_mat = Quaternion(cs_record['rotation']).inverse.rotation_matrix
            ego_fut_trajs = np.dot(rot_mat, ego_fut_trajs.T).T
            # drive command according to final fut step offset
            if ego_fut_trajs[-1][0] >= 2:
                command = np.array([1, 0, 0]) # Turn Right
            elif ego_fut_trajs[-1][0] <= -2:
                command = np.array([0, 1, 0]) # Turn Left
            else:
                command = np.array([0, 0, 1]) # Go Straight
            # get offset
            ego_fut_trajs = ego_fut_trajs[1:] - ego_fut_trajs[:-1]

            info['gt_boxes'] = gt_boxes
            info['gt_names'] = names
            info['gt_velocity'] = velocity.reshape(-1, 2)
            info['num_lidar_pts'] = np.array(
                [a['num_lidar_pts'] for a in annotations])
            info['num_radar_pts'] = np.array(
                [a['num_radar_pts'] for a in annotations])
            info['valid_flag'] = valid_flag
            info['instance_inds'] = instance_inds
            info['gt_agent_fut_trajs'] = gt_fut_trajs.astype(np.float32)
            info['gt_agent_fut_masks'] = gt_fut_masks.astype(np.float32)
            info['gt_ego_fut_trajs'] = ego_fut_trajs[:, :2].astype(np.float32)
            info['gt_ego_fut_masks'] = ego_fut_masks[1:].astype(np.float32)
            info['gt_ego_fut_cmd'] = command.astype(np.float32)
            info['ego_status'] = ego_status

        if sample['scene_token'] in train_scenes:
            train_nusc_infos.append(info)
        else:
            val_nusc_infos.append(info)

    return train_nusc_infos, val_nusc_infos
+
def get_ego_status(nusc, nusc_can_bus, sample):
    """Collect a 10-dim ego-vehicle status vector from the CAN bus.

    Layout: [accel(3), rotation_rate(3), vel(3), steering_angle(1)], all
    in the ego-vehicle frame. Falls back to all zeros when CAN bus data
    cannot be read for the scene.

    Args:
        nusc: NuScenes API instance.
        nusc_can_bus: NuScenesCanBus API instance.
        sample (dict): Sample record providing 'scene_token' and 'timestamp'.

    Returns:
        np.ndarray: float32 array of shape (10,).
    """
    ego_status = []
    ref_scene = nusc.get("scene", sample['scene_token'])
    try:
        pose_msgs = nusc_can_bus.get_messages(ref_scene['name'],'pose')
        steer_msgs = nusc_can_bus.get_messages(ref_scene['name'], 'steeranglefeedback')
        pose_uts = [msg['utime'] for msg in pose_msgs]
        steer_uts = [msg['utime'] for msg in steer_msgs]
        ref_utime = sample['timestamp']
        pose_index = locate_message(pose_uts, ref_utime)
        pose_data = pose_msgs[pose_index]
        steer_index = locate_message(steer_uts, ref_utime)
        steer_data = steer_msgs[steer_index]
        ego_status.extend(pose_data["accel"]) # acceleration in ego vehicle frame, m/s/s
        ego_status.extend(pose_data["rotation_rate"]) # angular velocity in ego vehicle frame, rad/s
        ego_status.extend(pose_data["vel"]) # velocity in ego vehicle frame, m/s
        ego_status.append(steer_data["value"]) # steering angle, positive: left turn, negative: right turn
    except Exception:
        # Best-effort fallback. A bare `except:` would also swallow
        # KeyboardInterrupt/SystemExit; catch only ordinary exceptions.
        ego_status = [0] * 10

    return np.array(ego_status).astype(np.float32)
+
def get_global_sensor_pose(rec, nusc):
    """Return the 4x4 lidar->global pose matrix for a sample record.

    Composes ego->global with sensor->ego for the LIDAR_TOP sample data
    of *rec*.
    """
    sd = nusc.get('sample_data', rec['data']['LIDAR_TOP'])
    ego_pose = nusc.get("ego_pose", sd["ego_pose_token"])
    calib = nusc.get("calibrated_sensor", sd["calibrated_sensor_token"])

    ego2global = transform_matrix(ego_pose["translation"], Quaternion(ego_pose["rotation"]), inverse=False)
    sensor2ego = transform_matrix(calib["translation"], Quaternion(calib["rotation"]), inverse=False)
    return ego2global.dot(sensor2ego)
+
def obtain_sensor2top(nusc,
                      sensor_token,
                      l2e_t,
                      l2e_r_mat,
                      e2g_t,
                      e2g_r_mat,
                      sensor_type='lidar'):
    """Obtain the info with RT matric from general sensor to Top LiDAR.

    Args:
        nusc (class): Dataset class in the nuScenes dataset.
        sensor_token (str): Sample data token corresponding to the
            specific sensor type.
        l2e_t (np.ndarray): Translation from lidar to ego in shape (1, 3).
        l2e_r_mat (np.ndarray): Rotation matrix from lidar to ego
            in shape (3, 3).
        e2g_t (np.ndarray): Translation from ego to global in shape (1, 3).
        e2g_r_mat (np.ndarray): Rotation matrix from ego to global
            in shape (3, 3).
        sensor_type (str): Sensor to calibrate. Default: 'lidar'.

    Returns:
        sweep (dict): Sweep information after transformation.
    """
    sd_rec = nusc.get('sample_data', sensor_token)
    cs_record = nusc.get('calibrated_sensor',
                         sd_rec['calibrated_sensor_token'])
    pose_record = nusc.get('ego_pose', sd_rec['ego_pose_token'])
    data_path = str(nusc.get_sample_data_path(sd_rec['token']))
    if os.getcwd() in data_path:  # path from lyftdataset is absolute path
        data_path = data_path.split(f'{os.getcwd()}/')[-1]  # relative path
    sweep = {
        'data_path': data_path,
        'type': sensor_type,
        'sample_data_token': sd_rec['token'],
        'sensor2ego_translation': cs_record['translation'],
        'sensor2ego_rotation': cs_record['rotation'],
        'ego2global_translation': pose_record['translation'],
        'ego2global_rotation': pose_record['rotation'],
        'timestamp': sd_rec['timestamp']
    }

    # Calibration of the sweep's own frame: sensor->ego and ego->global at
    # the sweep's timestamp (primed quantities relative to the key frame).
    l2e_r_s = sweep['sensor2ego_rotation']
    l2e_t_s = sweep['sensor2ego_translation']
    e2g_r_s = sweep['ego2global_rotation']
    e2g_t_s = sweep['ego2global_translation']

    # obtain the RT from sensor to Top LiDAR
    # sweep->ego->global->ego'->lidar
    l2e_r_s_mat = Quaternion(l2e_r_s).rotation_matrix
    e2g_r_s_mat = Quaternion(e2g_r_s).rotation_matrix
    # R and T are written for row-vector points (points @ R.T + T), hence
    # the transposed/inverted factors below.
    R = (l2e_r_s_mat.T @ e2g_r_s_mat.T) @ (
        np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T)
    T = (l2e_t_s @ e2g_r_s_mat.T + e2g_t_s) @ (
        np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T)
    # Subtract the key frame's ego/global offsets mapped into the key
    # lidar frame so T is expressed relative to the Top LiDAR origin.
    T -= e2g_t @ (np.linalg.inv(e2g_r_mat).T @ np.linalg.inv(l2e_r_mat).T
                  ) + l2e_t @ np.linalg.inv(l2e_r_mat).T
    sweep['sensor2lidar_rotation'] = R.T  # points @ R.T + T
    sweep['sensor2lidar_translation'] = T
    return sweep
+
def nuscenes_data_prep(root_path,
                       can_bus_root_path,
                       info_prefix,
                       version,
                       dataset_name,
                       out_dir,
                       max_sweeps=10):
    """Prepare '.pkl' info files for the nuScenes dataset.

    Thin wrapper around :func:`create_nuscenes_infos`.

    Args:
        root_path (str): Path of dataset root.
        can_bus_root_path (str): Root path of the nuScenes CAN bus data.
        info_prefix (str): The prefix of info filenames.
        version (str): Dataset version.
        dataset_name (str): The dataset class name (accepted for CLI
            symmetry; not used here).
        out_dir (str): Output directory of the generated info files.
        max_sweeps (int): Number of input consecutive frames. Default: 10
    """
    create_nuscenes_infos(
        root_path=root_path,
        out_path=out_dir,
        can_bus_root_path=can_bus_root_path,
        info_prefix=info_prefix,
        version=version,
        max_sweeps=max_sweeps)
+
+
+parser = argparse.ArgumentParser(description='Data converter arg parser')
+parser.add_argument('dataset', metavar='kitti', help='name of the dataset')
+parser.add_argument(
+ '--root-path',
+ type=str,
+ default='./data/kitti',
+ help='specify the root path of dataset')
+parser.add_argument(
+ '--canbus',
+ type=str,
+ default='./data',
+ help='specify the root path of nuScenes canbus')
+parser.add_argument(
+ '--version',
+ type=str,
+ default='v1.0',
+ required=False,
+ help='specify the dataset version, no need for kitti')
+parser.add_argument(
+ '--max-sweeps',
+ type=int,
+ default=10,
+ required=False,
+ help='specify sweeps of lidar per example')
+parser.add_argument(
+ '--out-dir',
+ type=str,
+ default='./data/kitti',
+ required='False',
+ help='name of info pkl')
+parser.add_argument('--extra-tag', type=str, default='kitti')
+parser.add_argument(
+ '--workers', type=int, default=4, help='number of threads to be used')
+args = parser.parse_args()
+
if __name__ == '__main__':
    if args.dataset == 'nuscenes':
        if args.version == 'v1.0-mini':
            # The mini split is a single self-contained version.
            versions = [args.version]
        else:
            # Full dataset: build both the trainval and the test infos.
            versions = [f'{args.version}-trainval', f'{args.version}-test']
        for data_version in versions:
            nuscenes_data_prep(
                root_path=args.root_path,
                can_bus_root_path=args.canbus,
                info_prefix=args.extra_tag,
                version=data_version,
                dataset_name='NuScenesDataset',
                out_dir=args.out_dir,
                max_sweeps=args.max_sweeps)
diff --git a/tools/data_converter/nuscenes_occlusion_converter.py b/tools/data_converter/nuscenes_occlusion_converter.py
new file mode 100644
index 0000000..9bbc4b4
--- /dev/null
+++ b/tools/data_converter/nuscenes_occlusion_converter.py
@@ -0,0 +1,1777 @@
+#!/usr/bin/env python3
+"""Fill annotation gaps in nuScenes info pkls and visualise the results.
+
+Two processing passes:
+
+1. **Gap interpolation** (CATR): fills frames between two consecutive
+ observations of the same instance using Constant Acceleration and Turn
+ Rate kinematics with a linear endpoint position correction.
+
+2. **Forward extrapolation** (CV): projects each track's last observed
+ state forward by up to ``--max-extrap-frames`` frames using Constant
+ Velocity.
+
+New flags in every sample info dict:
+ ``is_interpolated`` (bool array) — CATR gap-fill boxes
+ ``is_extrapolated`` (bool array) — CV tail-extrapolation boxes
+Directly observed boxes have both flags False.
+
+Sub-commands
+------------
+convert Run the annotation pipeline and save a new pkl.
+visualize Draw BEV track plots for example scenes from an existing pkl.
+
+Examples
+--------
+ python tools/data_converter/nuscenes_occlusion_converter.py convert \\
+ --input data/infos/nuscenes_infos_val.pkl \\
+ --output data/infos/nuscenes_infos_val_occ.pkl
+
+ python tools/data_converter/nuscenes_occlusion_converter.py visualize \\
+ --pkl data/infos/nuscenes_infos_val_occ.pkl \\
+ --num-scenes 6 --output-dir viz/
+"""
+
+import pickle
+import argparse
+import os
+import numpy as np
+from collections import defaultdict
+
+
+# ===========================================================================
+# CATR kinematics helpers
+# ===========================================================================
+
+def _normalize_angle(angle):
+ return (angle + np.pi) % (2.0 * np.pi) - np.pi
+
+
+def _integrate_catr(x0, y0, yaw0, v0, omega, a, t, clamp_v=False):
+ """Vectorised Riemann-sum integration of CATR kinematics from 0 to t.
+
+ clamp_v : if True, speed is clamped to [0, inf) so a decelerating object
+ stops rather than reversing (use for forward extrapolation).
+ """
+ n = max(200, int(abs(t) * 400))
+ dt = t / n
+ s = np.arange(n) * dt
+ theta = yaw0 + omega * s
+ v = v0 + a * s
+ if clamp_v:
+ v = np.maximum(v, 0.0)
+ x = x0 + float(np.sum(v * np.cos(theta))) * dt
+ y = y0 + float(np.sum(v * np.sin(theta))) * dt
+ return x, y
+
+
def catr_interpolate(state0, state1, t, T):
    """CATR interpolation between two endpoint states.

    The constant turn rate (omega = dyaw / T) and linear acceleration
    (a = dv / T) are derived from the endpoint states. The model is
    integrated to time t, and the terminal residual is blended in
    linearly so the trajectory hits state1 exactly at t == T.

    Args:
        state0, state1: 9-tuples (x, y, z, l, w, h, yaw, vx, vy).
        t: query time within the gap.
        T: total gap duration.

    Returns:
        9-tuple interpolated state; box size is taken from state0.
    """
    x0, y0, z0, l, w, h, yaw0, vx0, vy0 = state0
    x1, y1, z1, _, _, _, yaw1, vx1, vy1 = state1
    frac = t / T
    # NaN velocities (unknown in nuScenes) are treated as zero speed.
    speed0 = 0.0 if np.isnan(vx0) or np.isnan(vy0) else float(np.hypot(vx0, vy0))
    speed1 = 0.0 if np.isnan(vx1) or np.isnan(vy1) else float(np.hypot(vx1, vy1))
    turn_rate = _normalize_angle(yaw1 - yaw0) / T
    accel = (speed1 - speed0) / T
    x_t, y_t = _integrate_catr(x0, y0, yaw0, speed0, turn_rate, accel, t)
    x_T, y_T = _integrate_catr(x0, y0, yaw0, speed0, turn_rate, accel, T)
    # Endpoint correction: spread the residual (state1 minus the model's
    # terminal position) linearly over the gap.
    x_out = x_t + frac * (x1 - x_T)
    y_out = y_t + frac * (y1 - y_T)
    z_out = z0 + frac * (z1 - z0)
    yaw_t = _normalize_angle(yaw0 + turn_rate * t)
    speed_t = max(0.0, speed0 + accel * t)
    return (float(x_out), float(y_out), float(z_out),
            float(l), float(w), float(h), float(yaw_t),
            float(speed_t * np.cos(yaw_t)), float(speed_t * np.sin(yaw_t)))
+
+
+# ===========================================================================
+# Coordinate transform helpers
+# ===========================================================================
+
def _lidar2global_RT(info):
    """Return (R 3x3, t 3) of the lidar -> global transform for *info*.

    Composes lidar->ego with ego->global:
    p_global = R_e2g @ (R_l2e @ p + t_l2e) + t_e2g.
    """
    from pyquaternion import Quaternion
    rot_l2e = Quaternion(info['lidar2ego_rotation']).rotation_matrix
    trans_l2e = np.array(info['lidar2ego_translation'])
    rot_e2g = Quaternion(info['ego2global_rotation']).rotation_matrix
    trans_e2g = np.array(info['ego2global_translation'])
    rotation = rot_e2g @ rot_l2e
    translation = rot_e2g @ trans_l2e + trans_e2g
    return rotation, translation
+
+
def _box_to_global(box7, vel2, info):
    """Lift a lidar-frame box and velocity into a global-frame state tuple.

    Parameters
    ----------
    box7 : array-like (7,) [x, y, z, l, w, h, yaw] in lidar frame
    vel2 : array-like (2,) [vx, vy] in lidar frame; NaN marks unknown
    info : sample info dict providing the lidar -> global transform

    Returns
    -------
    9-tuple (x_g, y_g, z_g, l, w, h, yaw_g, vx_g, vy_g) in global frame
    """
    R, t = _lidar2global_RT(info)
    center_g = R @ np.array([box7[0], box7[1], box7[2]]) + t
    # Yaw of the lidar->global rotation itself, read off the matrix.
    frame_yaw = np.arctan2(R[1, 0], R[0, 0])
    yaw_g = _normalize_angle(float(box7[6]) + frame_yaw)
    if np.isnan(vel2[0]) or np.isnan(vel2[1]):
        # Unknown velocity maps to zero in the global frame.
        vx_g, vy_g = 0.0, 0.0
    else:
        # Vectors only rotate; the translation part of the transform does
        # not apply to velocities.
        vel_g = R @ np.array([float(vel2[0]), float(vel2[1]), 0.0])
        vx_g, vy_g = float(vel_g[0]), float(vel_g[1])
    return (float(center_g[0]), float(center_g[1]), float(center_g[2]),
            float(box7[3]), float(box7[4]), float(box7[5]),
            yaw_g, vx_g, vy_g)
+
+
def _state_global_to_lidar(state_g, info):
    """Bring a global-frame 9-tuple state into the lidar frame of *info*.

    Position and yaw receive the full inverse rigid transform; the
    velocity is rotated only, matching the gt_velocity convention used by
    the nuScenes converter.

    Parameters
    ----------
    state_g : 9-tuple (x_g, y_g, z_g, l, w, h, yaw_g, vx_g, vy_g)
    info : target sample info dict

    Returns
    -------
    9-tuple (x, y, z, l, w, h, yaw, vx, vy) in the target lidar frame.
    """
    x_g, y_g, z_g, l, w, h, yaw_g, vx_g, vy_g = state_g
    R, t = _lidar2global_RT(info)
    center_l = R.T @ (np.array([x_g, y_g, z_g]) - t)
    frame_yaw = np.arctan2(R[1, 0], R[0, 0])
    yaw_l = _normalize_angle(yaw_g - frame_yaw)
    # Inverse rotation only for the velocity vector.
    vel_l = R.T @ np.array([vx_g, vy_g, 0.0])
    return (float(center_l[0]), float(center_l[1]), float(center_l[2]),
            float(l), float(w), float(h), float(yaw_l),
            float(vel_l[0]), float(vel_l[1]))
+
+
def _global_xy_to_lidar_xy(positions_g, info):
    """Project (N, 2) global XY positions into the lidar frame of *info*.

    Zero-pads to 3-D, applies the inverse rigid transform R.T @ (p - t)
    row-wise, and returns only the XY components (Z is discarded).
    """
    R, t = _lidar2global_RT(info)
    padded = np.zeros((len(positions_g), 3))
    padded[:, :2] = positions_g
    # (p - t) @ R is the row-vector form of R.T @ (p - t).
    return ((padded - t) @ R)[:, :2]
+
+
+# ===========================================================================
+# ML-prediction helpers
+# ===========================================================================
+
+def _load_ml_predictions(pred_path):
+ """Load a UniTraj inference NPZ and build an instance-token index.
+
+ Returns
+ -------
+ preds : ndarray (N, 6, 60, 2) agent-centric absolute XY offsets from
+ the agent's last-observed position, in a
+ frame whose +X axis aligns with the agent's
+ heading.
+ worlds : ndarray (N, 10) center_objects_world — [:2] global XY
+ origin, [6] global heading (yaw).
+ lookup : dict str → list[int] instance token → row indices in preds.
+ """
+ npz = np.load(pred_path, allow_pickle=True)
+ preds = npz['predictions'] # (N, 6, 60, 2)
+ meta = npz['metadata'].item() # flat dict
+ worlds = meta['center_objects_world'].astype(np.float64) # (N, 10)
+ ids = meta['center_objects_id'] # (N,) str
+ lookup = defaultdict(list)
+ for i, inst in enumerate(ids):
+ lookup[str(inst)].append(i)
+ return preds, worlds, lookup
+
+
+def _find_pred_index(inst_id, x_g, y_g, worlds, lookup, pos_tol=2.0):
+ """Return the row index in *worlds*/*preds* for this instance near (x_g, y_g).
+
+ Matches by instance token first, then picks the candidate whose stored
+ global XY is within *pos_tol* metres. Returns None if no match.
+ """
+ cands = lookup.get(str(inst_id), [])
+ if not cands:
+ return None
+ dists = [np.hypot(worlds[ci, 0] - x_g, worlds[ci, 1] - y_g) for ci in cands]
+ best = int(np.argmin(dists))
+ return cands[best] if dists[best] < pos_tol else None
+
+
+def _pred_global_xy(ml_preds_i, world, step):
+ """Convert agent-centric prediction at *step* to global (x, y).
+
+ ml_preds_i : (6, 60, 2) — mode 0 is used (probabilities not saved yet).
+ world : (10,) — world[:2] = global origin, world[6] = global heading.
+ step : 0-based UniTraj step (0 → 0.1 s ahead of prediction moment).
+ Returns (x_g, y_g) or None when step is out of [0, 59].
+ """
+ if not (0 <= step < ml_preds_i.shape[1]):
+ return None
+ pred_ac = ml_preds_i[0, step] # mode 0, shape (2,)
+ theta = float(world[6])
+ c, s = np.cos(theta), np.sin(theta)
+ return (float(c * pred_ac[0] - s * pred_ac[1] + world[0]),
+ float(s * pred_ac[0] + c * pred_ac[1] + world[1]))
+
+
+def _dt_to_step(dt_seconds, unitraj_dt):
+ """Seconds offset from prediction moment → 0-based UniTraj step index.
+
+ Step 0 corresponds to *unitraj_dt* seconds ahead.
+ Use unitraj_dt=0.1 for 10 Hz models, 0.5 for 2 Hz models.
+ """
+ return int(round(dt_seconds / unitraj_dt)) - 1
+
+
+def _build_instance_token_map(nuscenes_dataroot, version='v1.0-trainval'):
+ """Return a dict mapping integer nuScenes instance indices → token strings.
+
+ ``nusc.getind('instance', token)`` returns the 0-based position of a record
+ in the instance table, which is the same as its index in instance.json.
+ This map lets us convert the integer ``inst_ind`` stored by ForeSight back
+ to the 32-char token string stored in the UniTraj NPZ.
+ """
+ import json
+ path = os.path.join(nuscenes_dataroot, version, 'instance.json')
+ with open(path) as fh:
+ instances = json.load(fh)
+ return {i: rec['token'] for i, rec in enumerate(instances)}
+
+
+# ===========================================================================
+# Motion-model fitting helper
+# ===========================================================================
+
def _fit_motion_model(appearances, infos, min_history,
                      ca_noise_thr, ca_max_thr, ca_consistency_thr,
                      omega_noise_thr, omega_max_thr, omega_consistency_thr):
    """Fit scalar acceleration and turn-rate from the trailing consecutive
    observation run of a track.

    Only frames that are immediately consecutive in the scene (frame_pos
    difference == 1) are used, starting from the last observation and
    working backwards. Fitted values are accepted only when the
    per-interval estimates are consistent (low std) and their mean falls
    inside the [noise_thr, max_thr] magnitude band; otherwise 0.0 is
    returned for that quantity.

    Parameters
    ----------
    appearances : list of (frame_pos, info_index, box_index) tuples,
        ordered by frame.
    infos : list of sample info dicts (indexed by info_index).
    min_history : minimum length of the trailing run required to fit.
    ca_* / omega_* : acceptance thresholds for acceleration / turn rate.

    Returns
    -------
    (a_fit, omega_fit) :
        a_fit — scalar acceleration (m/s²); 0.0 when not reliably fitted.
        omega_fit — turn rate (rad/s); 0.0 when not reliably fitted.
    """
    # Build the longest trailing run of consecutive scene frames.
    run = [appearances[-1]]
    for i in range(len(appearances) - 2, -1, -1):
        if appearances[i + 1][0] - appearances[i][0] == 1:
            run.insert(0, appearances[i])
        else:
            break
    if len(run) < min_history:
        return 0.0, 0.0

    # Collect (timestamp, speed, box yaw) per frame.
    # Box yaw is directly annotated; velocity is inferred from position differences
    # and is therefore noisier — use box yaw for turn-rate estimation.
    states = []
    for _, gi, bi in run:
        info = infos[gi]
        sg = _box_to_global(info['gt_boxes'][bi], info['gt_velocity'][bi], info)
        spd = float(np.hypot(sg[7], sg[8]))
        # Timestamps are stored in microseconds; convert to seconds.
        states.append((info['timestamp'] * 1e-6, spd, float(sg[6])))

    # Per-interval acceleration and turn-rate.
    accels, omegas = [], []
    for i in range(len(states) - 1):
        t0, v0, h0 = states[i]
        t1, v1, h1 = states[i + 1]
        dt = t1 - t0
        if dt <= 0:
            continue
        accels.append((v1 - v0) / dt)
        omegas.append(_normalize_angle(h1 - h0) / dt)

    if not accels:
        return 0.0, 0.0

    a_mean = float(np.mean(accels))
    a_std = float(np.std(accels)) if len(accels) > 1 else 0.0
    om_mean = float(np.mean(omegas))
    om_std = float(np.std(omegas)) if len(omegas) > 1 else 0.0

    # Accept a fit only when the intervals agree (std below the
    # consistency threshold) and the mean magnitude lies in the
    # [noise, max] band; otherwise fall back to 0.0 (constant velocity /
    # straight line).
    a_fit = (a_mean
             if a_std < ca_consistency_thr
             and ca_noise_thr <= abs(a_mean) <= ca_max_thr
             else 0.0)
    omega_fit = (om_mean
                 if om_std < omega_consistency_thr
                 and omega_noise_thr <= abs(om_mean) <= omega_max_thr
                 else 0.0)
    return a_fit, omega_fit
+
+
+# ===========================================================================
+# Shared annotation-append helper
+# ===========================================================================
+
+def _append_annotation(info, x, y, z, l, w, h, yaw, vx, vy,
+                       inst_ind, class_name, is_interp, is_extrap,
+                       fut_trajs=None, fut_masks=None):
+    """Append one synthesized box to *info*'s annotation arrays, in place.
+
+    The box pose (x/y/z/l/w/h/yaw) and velocity (vx/vy) are expected in the
+    sample's lidar frame. ``fut_trajs`` / ``fut_masks`` default to all-zero
+    future deltas with an all-zero mask when not provided; when given they
+    are reshaped to ``(1, T, 2)`` / ``(1, T)`` with T taken from the sample's
+    existing ``gt_agent_fut_trajs``.
+
+    Synthesized entries are stored with ``valid_flag=False`` and zero
+    lidar/radar point counts, and tagged via ``is_interpolated`` /
+    ``is_extrapolated`` so downstream passes can tell them apart from
+    original observations.
+    """
+    new_box = np.array([[x, y, z, l, w, h, yaw]], dtype=np.float32)
+    new_vel = np.array([[vx, vy]], dtype=np.float32)
+    fut_ts = info['gt_agent_fut_trajs'].shape[1]
+    if fut_trajs is None:
+        fut_trajs = np.zeros((1, fut_ts, 2), dtype=np.float32)
+    else:
+        fut_trajs = np.asarray(fut_trajs, dtype=np.float32).reshape(1, fut_ts, 2)
+    if fut_masks is None:
+        fut_masks = np.zeros((1, fut_ts), dtype=np.float32)
+    else:
+        fut_masks = np.asarray(fut_masks, dtype=np.float32).reshape(1, fut_ts)
+    info['gt_boxes'] = np.concatenate([info['gt_boxes'], new_box], axis=0)
+    info['gt_velocity'] = np.concatenate([info['gt_velocity'], new_vel], axis=0)
+    info['gt_agent_fut_trajs'] = np.concatenate([info['gt_agent_fut_trajs'], fut_trajs], axis=0)
+    info['gt_agent_fut_masks'] = np.concatenate([info['gt_agent_fut_masks'], fut_masks], axis=0)
+    info['gt_names'] = np.append(info['gt_names'], class_name)
+    info['valid_flag'] = np.append(info['valid_flag'], False)
+    info['num_lidar_pts'] = np.append(info['num_lidar_pts'], 0)
+    info['num_radar_pts'] = np.append(info['num_radar_pts'], 0)
+    info['is_interpolated'] = np.append(info['is_interpolated'], is_interp)
+    info['is_extrapolated'] = np.append(info['is_extrapolated'], is_extrap)
+    # instance_inds is a plain Python list, not an ndarray — append directly.
+    info['instance_inds'].append(inst_ind)
+
+
+# ===========================================================================
+# Convert sub-command
+# ===========================================================================
+
+def cmd_convert(args):
+    """Densify a nuScenes info PKL with interpolated and extrapolated boxes.
+
+    Pipeline (per-box states are converted to the global frame for all
+    motion-model math, then back to each sample's lidar frame for storage):
+
+    1. Pass 1 — gap interpolation: fill frames missing inside a track with
+       CATR-interpolated states between the two bracketing observations.
+    2. Pass 2 — forward extrapolation: extend each track past its last
+       original observation, preferring UniTraj ML predictions (with a
+       positional sanity check against CV) and falling back to a
+       CP → CATR/CA/CTR → CV model hierarchy otherwise.
+    3. Optional two-level ego-distance filtering (``--max-dist``).
+
+    Reads ``args.input`` and writes the augmented infos to ``args.output``.
+    All synthesized boxes are tagged via ``is_interpolated`` /
+    ``is_extrapolated`` (added here to every sample).
+    """
+    print(f'Loading {args.input} ...')
+    with open(args.input, 'rb') as f:
+        data = pickle.load(f)
+    infos = data['infos']
+    print(f'  {len(infos)} samples loaded.')
+
+    # Initialize provenance flags on every sample; original boxes are neither.
+    for info in infos:
+        n = len(info['gt_boxes'])
+        info['is_interpolated'] = np.zeros(n, dtype=bool)
+        info['is_extrapolated'] = np.zeros(n, dtype=bool)
+
+    # Load ML predictions if provided
+    ml_preds = ml_worlds = ml_lookup = None
+    inst_token_map = {}  # int index → 32-char token string
+    if getattr(args, 'predictions', None):
+        print(f'Loading ML predictions from {args.predictions} ...')
+        ml_preds, ml_worlds, ml_lookup = _load_ml_predictions(args.predictions)
+        print(f'  {len(ml_preds)} prediction entries loaded.')
+        dataroot = getattr(args, 'nuscenes_dataroot', None)
+        if dataroot:
+            print(f'Building instance-token map from {dataroot} ...')
+            inst_token_map = _build_instance_token_map(dataroot)
+            print(f'  {len(inst_token_map)} instance records indexed.')
+        else:
+            print('WARNING: --nuscenes-dataroot not set; ML predictions will '
+                  'not be matched (all extrapolations fall back to CV).')
+
+    # Group samples by scene, ordered chronologically within each scene.
+    scene_to_indices = defaultdict(list)
+    for gi, info in enumerate(infos):
+        scene_to_indices[info['scene_token']].append(gi)
+    for sc in scene_to_indices:
+        scene_to_indices[sc].sort(key=lambda i: infos[i]['timestamp'])
+
+    # ------------------------------------------------------------------
+    # Pass 1: gap interpolation (CATR)
+    # ------------------------------------------------------------------
+    total_interpolated = 0
+    for _, global_indices in scene_to_indices.items():
+        tracks = defaultdict(list)
+        for frame_pos, gi in enumerate(global_indices):
+            for bi, inst_ind in enumerate(infos[gi]['instance_inds']):
+                tracks[inst_ind].append((frame_pos, gi, bi))
+
+        for inst_ind, appearances in tracks.items():
+            for k in range(len(appearances) - 1):
+                fp0, gi0, bi0 = appearances[k]
+                fp1, gi1, bi1 = appearances[k + 1]
+                if fp1 - fp0 <= 1:
+                    continue
+                info0, info1 = infos[gi0], infos[gi1]
+                # Convert both endpoints to global frame so CATR operates
+                # in a single consistent coordinate system.
+                state0_g = _box_to_global(
+                    info0['gt_boxes'][bi0], info0['gt_velocity'][bi0], info0)
+                state1_g = _box_to_global(
+                    info1['gt_boxes'][bi1], info1['gt_velocity'][bi1], info1)
+                class_name = info0['gt_names'][bi0]
+                t0_s = info0['timestamp'] * 1e-6
+                T = info1['timestamp'] * 1e-6 - t0_s
+                for gap_fp in range(fp0 + 1, fp1):
+                    gap_gi = global_indices[gap_fp]
+                    gap_info = infos[gap_gi]
+                    t = gap_info['timestamp'] * 1e-6 - t0_s
+                    result_g = catr_interpolate(state0_g, state1_g, t, T)
+                    # Convert the global-frame result back to the target frame's
+                    # lidar coordinates before storing.
+                    x, y, z, l, w, h, yaw, vx, vy = _state_global_to_lidar(
+                        result_g, gap_info)
+                    # Build future trajectory: CATR positions for future scene
+                    # frames that still fall within the known gap [t0_s, t0_s+T].
+                    fut_ts_n = gap_info['gt_agent_fut_trajs'].shape[1]
+                    fut_positions_g = [np.array([result_g[0], result_g[1]])]
+                    for k in range(1, fut_ts_n + 1):
+                        fut_fp = gap_fp + k
+                        if fut_fp >= len(global_indices):
+                            break
+                        t_fut_rel = infos[global_indices[fut_fp]]['timestamp'] * 1e-6 - t0_s
+                        if t_fut_rel > T:
+                            break
+                        fut_g = catr_interpolate(state0_g, state1_g, t_fut_rel, T)
+                        fut_positions_g.append(np.array([fut_g[0], fut_g[1]]))
+                    all_l = _global_xy_to_lidar_xy(np.array(fut_positions_g), gap_info)
+                    # Store future steps as per-step lidar-frame deltas.
+                    fut_trajs = np.zeros((fut_ts_n, 2), dtype=np.float32)
+                    fut_masks = np.zeros(fut_ts_n, dtype=np.float32)
+                    for k in range(1, len(fut_positions_g)):
+                        fut_trajs[k - 1] = all_l[k] - all_l[k - 1]
+                        fut_masks[k - 1] = 1.0
+                    _append_annotation(gap_info, x, y, z, l, w, h, yaw, vx, vy,
+                                       inst_ind, class_name,
+                                       is_interp=True, is_extrap=False,
+                                       fut_trajs=fut_trajs, fut_masks=fut_masks)
+                    total_interpolated += 1
+
+    # ------------------------------------------------------------------
+    # Pass 2: forward extrapolation (ML prediction or CV fallback)
+    # ------------------------------------------------------------------
+    # NuScenes gt_names that map to MetaDrive VEHICLE / PEDESTRIAN / CYCLIST and
+    # are therefore included in UniTraj inference predictions.
+    # Source: scenarionet/converter/nuscenes/type.py
+    #   VEHICLE_TYPE → MetaDriveType.VEHICLE → object_type 1
+    #   HUMAN_TYPE   → MetaDriveType.PEDESTRIAN → object_type 2
+    #   BICYCLE_TYPE → MetaDriveType.CYCLIST → object_type 3
+    # Other categories (barrier, traffic_cone, movable_object.*, animal …)
+    # are not predicted — CV fallback is correct for those.
+    _ML_CLASSES = {
+        # VEHICLE_TYPE
+        'car', 'truck', 'bus', 'trailer', 'construction_vehicle',
+        'vehicle.emergency.ambulance', 'vehicle.emergency.police',
+        # HUMAN_TYPE
+        'pedestrian',
+        'human.pedestrian.stroller', 'human.pedestrian.personal_mobility',
+        'human.pedestrian.construction_worker', 'human.pedestrian.police_officer',
+        # BICYCLE_TYPE
+        'motorcycle', 'bicycle',
+    }
+    # Motorized road users eligible for CA/CTR motion-model fitting.
+    _MOTORIZED_CLASSES = {
+        'car', 'truck', 'bus', 'trailer', 'construction_vehicle',
+        'vehicle.emergency.ambulance', 'vehicle.emergency.police',
+        'motorcycle',
+    }
+    # UniTraj VEHICLE type — stationary instances (total displacement < 2 m)
+    # are filtered out during training, so ML predictions for these are
+    # unreliable when the agent is near-stationary. PEDESTRIAN and CYCLIST
+    # have no displacement filter and keep their ML predictions when slow.
+    _VEHICLE_CLASSES = {
+        'car', 'truck', 'bus', 'trailer', 'construction_vehicle',
+        'vehicle.emergency.ambulance', 'vehicle.emergency.police',
+    }
+    total_extrap = 0
+    total_extrap_ml = 0
+    total_extrap_ml_shifted = 0   # ML matches via earlier-frame fallback
+    cv_by_class: dict = defaultdict(int)   # CV fallback counts broken down by class
+    cv_model_counter = defaultdict(int)    # fallback model breakdown (CP/CV/CA/CTR/CATR)
+    cv_no_entry: int = 0     # CV because instance has zero NPZ rows (unrecoverable)
+    cv_dist_fail: int = 0    # CV because closest NPZ row exceeds pos_tol (fixable?)
+    cv_sanity_fail: int = 0  # CV because ML 0.5 s point diverges too far from CV
+    if not args.no_extrapolate:
+        for _, global_indices in scene_to_indices.items():
+            n_frames = len(global_indices)
+            present = [set(infos[gi]['instance_inds']) for gi in global_indices]
+
+            # Only original observations seed extrapolation (not pass-1 output).
+            orig_tracks = defaultdict(list)
+            for frame_pos, gi in enumerate(global_indices):
+                info = infos[gi]
+                for bi, inst_ind in enumerate(info['instance_inds']):
+                    if not info['is_interpolated'][bi] and not info['is_extrapolated'][bi]:
+                        orig_tracks[inst_ind].append((frame_pos, gi, bi))
+
+            for inst_ind, appearances in orig_tracks.items():
+                last_fp, last_gi, last_bi = appearances[-1]
+                if last_fp >= n_frames - 1:
+                    continue
+                info_ref = infos[last_gi]
+                class_name = info_ref['gt_names'][last_bi]
+                # Convert last observation to global frame.
+                x_g, y_g, z_g, l, w, h, yaw_g, vx_g, vy_g = _box_to_global(
+                    info_ref['gt_boxes'][last_bi],
+                    info_ref['gt_velocity'][last_bi],
+                    info_ref)
+                t_ref = info_ref['timestamp'] * 1e-6
+
+                # ML-prediction lookup: find the NPZ row whose stored global
+                # position is closest to (x_g, y_g) for this instance.
+                # ForeSight stores integer indices; the NPZ stores token strings
+                # — convert via inst_token_map before querying ml_lookup.
+                #
+                # Primary match: prediction moment == last observed frame.
+                # Fallback match: use an earlier observed frame whose position
+                # aligns with a sliding-window prediction moment. This lets
+                # late-scene instances (last observed at frames 140-195 in 10 Hz
+                # terms, beyond the sf=115 coverage) still receive ML-based
+                # extrapolation. pred_time_offset (seconds) is then added to
+                # every dt so that step indices remain relative to the NPZ
+                # prediction moment rather than the last observation.
+                pred_row = None
+                pred_time_offset = 0.0   # seconds from NPZ pred-moment to last_obs
+                if ml_lookup is not None and class_name in _ML_CLASSES:
+                    inst_token = inst_token_map.get(inst_ind)
+                    if inst_token is not None:
+                        pred_row = _find_pred_index(
+                            inst_token, x_g, y_g, ml_worlds, ml_lookup)
+
+                        if pred_row is None and len(appearances) > 1:
+                            # Try earlier observations, most-recent first.
+                            for earlier_fp, earlier_gi, earlier_bi in reversed(appearances[:-1]):
+                                info_e = infos[earlier_gi]
+                                state_e = _box_to_global(
+                                    info_e['gt_boxes'][earlier_bi],
+                                    info_e['gt_velocity'][earlier_bi],
+                                    info_e)
+                                row = _find_pred_index(
+                                    inst_token, state_e[0], state_e[1],
+                                    ml_worlds, ml_lookup)
+                                if row is not None:
+                                    dt_off = t_ref - info_e['timestamp'] * 1e-6
+                                    # dt_off must be positive (earlier frame) and
+                                    # within the 6 s prediction horizon.
+                                    if 0 < dt_off < 6.0:
+                                        pred_row = row
+                                        pred_time_offset = dt_off
+                                    break
+
+                # Sanity-check: reject the ML prediction if its 0.5 s point
+                # diverges too far from the CV prediction at the same time.
+                # Threshold = 0.5 + alpha * speed (m) — tighter for slow movers.
+                if pred_row is not None:
+                    step_check = _dt_to_step(pred_time_offset + 0.5, args.unitraj_dt)
+                    if 0 <= step_check < 60:
+                        xy_ml_check = _pred_global_xy(
+                            ml_preds[pred_row], ml_worlds[pred_row], step_check)
+                        if xy_ml_check is not None:
+                            speed = float(np.hypot(vx_g, vy_g))
+                            x_cv_check = x_g + vx_g * 0.5
+                            y_cv_check = y_g + vy_g * 0.5
+                            tol = 0.5 + 0.3 * speed
+                            if np.hypot(xy_ml_check[0] - x_cv_check,
+                                        xy_ml_check[1] - y_cv_check) > tol:
+                                pred_row = None
+                                cv_sanity_fail += 1
+
+                # Diagnostic: classify why ML lookup failed for ML classes.
+                if pred_row is None and class_name in _ML_CLASSES and ml_lookup is not None:
+                    inst_token_d = inst_token_map.get(inst_ind)
+                    if inst_token_d is not None:
+                        all_cands = ml_lookup.get(str(inst_token_d), [])
+                        if all_cands:
+                            cv_dist_fail += 1   # has NPZ entries but all exceeded pos_tol
+                        else:
+                            cv_no_entry += 1    # never appeared as center agent in inference
+
+                # Per-instance motion state (used by all fallback models below).
+                # Use box yaw directly — it is annotated by humans and more
+                # reliable than velocity-direction heading, which is inferred
+                # from position differences.
+                v0 = float(np.hypot(vx_g, vy_g))
+                drive_yaw_g = yaw_g
+
+                # Override ML for stationary vehicles: UniTraj filters out
+                # VEHICLE instances with total displacement < 2 m, so ML
+                # predictions for slow vehicles are unreliable. Pedestrians
+                # and cyclists have no displacement filter and keep their ML
+                # predictions even when near-stationary.
+                if (pred_row is not None
+                        and v0 < args.cv_stationary_thr
+                        and class_name in _VEHICLE_CLASSES):
+                    pred_row = None
+
+                # Fit CA/CTR model from history for motorized classes without ML.
+                ca_accel = 0.0
+                ca_omega = 0.0
+                if pred_row is None and class_name in _MOTORIZED_CLASSES:
+                    ca_accel, ca_omega = _fit_motion_model(
+                        appearances, infos, args.min_ca_history,
+                        args.ca_noise_thr, args.ca_max_thr, args.ca_consistency_thr,
+                        args.omega_noise_thr, args.omega_max_thr,
+                        args.omega_consistency_thr)
+
+                # Pre-compute ML handover state. Step 59 (the final UniTraj
+                # step) snaps close to zero for most predictions and is
+                # discarded. Steps 0–58 are reliable and used as-is.
+                # _ml_last : last ML step we use (= 58, skipping only 59)
+                _span = 5
+                _ml_last = 58
+                ml_end_x, ml_end_y = x_g, y_g
+                ml_end_vx, ml_end_vy = vx_g, vy_g
+                ml_end_yaw = drive_yaw_g
+                ml_end_dt_offset = 0.0
+                if pred_row is not None:
+                    xy_end = _pred_global_xy(
+                        ml_preds[pred_row], ml_worlds[pred_row], _ml_last)
+                    if xy_end is not None:
+                        ml_end_x, ml_end_y = xy_end
+                        # Backward-difference velocity at _ml_last.
+                        xy_bend = _pred_global_xy(
+                            ml_preds[pred_row], ml_worlds[pred_row], _ml_last - _span)
+                        if xy_bend is not None:
+                            ml_end_vx = (ml_end_x - xy_bend[0]) / (_span * args.unitraj_dt)
+                            ml_end_vy = (ml_end_y - xy_bend[1]) / (_span * args.unitraj_dt)
+                        # Yaw: central difference centred on _ml_last.
+                        xy_pend = _pred_global_xy(
+                            ml_preds[pred_row], ml_worlds[pred_row], _ml_last - 2 * _span)
+                        if xy_pend is not None:
+                            dx_end = ml_end_x - xy_pend[0]
+                            dy_end = ml_end_y - xy_pend[1]
+                            if np.hypot(dx_end, dy_end) > 0.05:
+                                ml_end_yaw = float(np.arctan2(dy_end, dx_end))
+                        ml_end_dt_offset = (_ml_last + 1) * args.unitraj_dt - pred_time_offset
+
+                frames_added = 0
+                for fp in range(last_fp + 1, n_frames):
+                    if inst_ind in present[fp] or frames_added >= args.max_extrap_frames:
+                        break
+                    gi = global_indices[fp]
+                    dt = infos[gi]['timestamp'] * 1e-6 - t_ref
+
+                    # --- Position, heading, velocity ---
+                    # pred_time_offset shifts dt so steps are relative to the
+                    # NPZ prediction moment (which may precede last_obs).
+                    step = (_dt_to_step(pred_time_offset + dt, args.unitraj_dt)
+                            if pred_row is not None else -1)
+                    xy = (_pred_global_xy(ml_preds[pred_row], ml_worlds[pred_row], step)
+                          if (pred_row is not None and 0 <= step <= _ml_last) else None)
+
+                    if xy is not None:
+                        x_ep, y_ep = xy
+                        # Heading: clamp to nearest step where central difference
+                        # is valid (both neighbours in [0, 59]) so boundary steps
+                        # borrow the adjacent central-diff yaw instead of using
+                        # noisy one-sided differences.
+                        _span = 5
+                        step_cd = max(_span, min(step, 59 - _span))
+                        xy_n_cd = _pred_global_xy(ml_preds[pred_row], ml_worlds[pred_row], step_cd + _span)
+                        xy_p_cd = _pred_global_xy(ml_preds[pred_row], ml_worlds[pred_row], step_cd - _span)
+                        if xy_n_cd is not None and xy_p_cd is not None:
+                            dx, dy = xy_n_cd[0] - xy_p_cd[0], xy_n_cd[1] - xy_p_cd[1]
+                        else:
+                            dx, dy = 0.0, 0.0
+                        yaw_ep = (float(np.arctan2(dy, dx))
+                                  if np.hypot(dx, dy) > 0.05 else yaw_g)
+                        # Velocity: same clamped central difference as yaw —
+                        # reuse xy_n_cd / xy_p_cd already fetched above.
+                        if xy_n_cd is not None and xy_p_cd is not None:
+                            vx_ep = (xy_n_cd[0] - xy_p_cd[0]) / (2 * _span * args.unitraj_dt)
+                            vy_ep = (xy_n_cd[1] - xy_p_cd[1]) / (2 * _span * args.unitraj_dt)
+                        else:
+                            vx_ep, vy_ep = vx_g, vy_g
+                        state_g = (x_ep, y_ep, z_g, l, w, h, yaw_ep, vx_ep, vy_ep)
+                        total_extrap_ml += 1
+                        if pred_time_offset > 0:
+                            total_extrap_ml_shifted += 1
+                    else:
+                        # Fallback hierarchy: CP → CATR/CA/CTR → CV.
+                        # When ML was available but its horizon is exhausted,
+                        # anchor from the ML step-59 endpoint to avoid a
+                        # positional jump back to the original observation.
+                        cv_by_class[class_name] += 1
+                        if pred_row is not None:
+                            ref_x, ref_y = ml_end_x, ml_end_y
+                            ref_vx, ref_vy = ml_end_vx, ml_end_vy
+                            ref_yaw = ml_end_yaw
+                            ref_v0 = float(np.hypot(ref_vx, ref_vy))
+                            ref_dt = dt - ml_end_dt_offset
+                        else:
+                            ref_x, ref_y = x_g, y_g
+                            ref_vx, ref_vy = vx_g, vy_g
+                            ref_yaw = drive_yaw_g
+                            ref_v0 = v0
+                            ref_dt = dt
+                        if ref_v0 < args.cv_stationary_thr:
+                            model_key = 'CP'
+                            state_g = (ref_x, ref_y, z_g, l, w, h, ref_yaw, 0.0, 0.0)
+                        elif ca_accel != 0.0 or ca_omega != 0.0:
+                            if ca_accel != 0.0 and ca_omega != 0.0:
+                                model_key = 'CATR'
+                            elif ca_accel != 0.0:
+                                model_key = 'CA'
+                            else:
+                                model_key = 'CTR'
+                            x_ca, y_ca = _integrate_catr(
+                                ref_x, ref_y, ref_yaw, ref_v0, ca_omega, ca_accel, ref_dt,
+                                clamp_v=True)
+                            yaw_ca = _normalize_angle(ref_yaw + ca_omega * ref_dt)
+                            v_ca = max(0.0, ref_v0 + ca_accel * ref_dt)
+                            state_g = (x_ca, y_ca, z_g, l, w, h, yaw_ca,
+                                       v_ca * np.cos(yaw_ca), v_ca * np.sin(yaw_ca))
+                        else:
+                            model_key = 'CV'
+                            state_g = (ref_x + ref_vx*ref_dt, ref_y + ref_vy*ref_dt, z_g,
+                                       l, w, h, ref_yaw, ref_vx, ref_vy)
+                        cv_model_counter[model_key] += 1
+
+                    x, y, z, l_, w_, h_, yaw, vx, vy = _state_global_to_lidar(
+                        state_g, infos[gi])
+
+                    # --- Future trajectory ---
+                    fut_ts_n = infos[gi]['gt_agent_fut_trajs'].shape[1]
+                    fut_positions_g = [np.array([state_g[0], state_g[1]])]
+                    for k in range(1, fut_ts_n + 1):
+                        fut_fp = fp + k
+                        if fut_fp >= n_frames:
+                            break
+                        dt_fut = infos[global_indices[fut_fp]]['timestamp'] * 1e-6 - t_ref
+                        step_fut = (_dt_to_step(pred_time_offset + dt_fut, args.unitraj_dt)
+                                    if pred_row is not None else -1)
+                        xy_fut = (_pred_global_xy(ml_preds[pred_row], ml_worlds[pred_row], step_fut)
+                                  if (pred_row is not None and 0 <= step_fut <= _ml_last) else None)
+                        if xy_fut is not None:
+                            fut_positions_g.append(np.array([xy_fut[0], xy_fut[1]]))
+                        else:
+                            # Same ML-end anchoring as the main fallback above.
+                            if pred_row is not None:
+                                f_x, f_y = ml_end_x, ml_end_y
+                                f_vx, f_vy = ml_end_vx, ml_end_vy
+                                f_yaw = ml_end_yaw
+                                f_v0 = float(np.hypot(f_vx, f_vy))
+                                f_dt = dt_fut - ml_end_dt_offset
+                            else:
+                                f_x, f_y = x_g, y_g
+                                f_vx, f_vy = vx_g, vy_g
+                                f_yaw = drive_yaw_g
+                                f_v0 = v0
+                                f_dt = dt_fut
+                            if f_v0 < args.cv_stationary_thr:
+                                fut_positions_g.append(np.array([f_x, f_y]))
+                            elif ca_accel != 0.0 or ca_omega != 0.0:
+                                x_ca, y_ca = _integrate_catr(
+                                    f_x, f_y, f_yaw, f_v0, ca_omega, ca_accel, f_dt,
+                                    clamp_v=True)
+                                fut_positions_g.append(np.array([x_ca, y_ca]))
+                            else:
+                                fut_positions_g.append(
+                                    np.array([f_x + f_vx*f_dt, f_y + f_vy*f_dt]))
+                    all_l = _global_xy_to_lidar_xy(np.array(fut_positions_g), infos[gi])
+                    fut_trajs = np.zeros((fut_ts_n, 2), dtype=np.float32)
+                    fut_masks = np.zeros(fut_ts_n, dtype=np.float32)
+                    for k in range(1, len(fut_positions_g)):
+                        fut_trajs[k - 1] = all_l[k] - all_l[k - 1]
+                        fut_masks[k - 1] = 1.0
+                    _append_annotation(infos[gi], x, y, z, l_, w_, h_, yaw,
+                                       vx, vy, inst_ind, class_name,
+                                       is_interp=False, is_extrap=True,
+                                       fut_trajs=fut_trajs, fut_masks=fut_masks)
+                    present[fp].add(inst_ind)
+                    frames_added += 1
+                    total_extrap += 1
+
+    # Summary statistics over all samples.
+    n_orig = sum(int(np.sum(~i['is_interpolated'] & ~i['is_extrapolated'])) for i in infos)
+    n_interp = sum(int(np.sum( i['is_interpolated'])) for i in infos)
+    n_extrap = sum(int(np.sum( i['is_extrapolated'])) for i in infos)
+    n_cv_non_ml = sum(v for k, v in cv_by_class.items() if k not in _ML_CLASSES)
+    n_cv_ml_cls = sum(v for k, v in cv_by_class.items() if k in _ML_CLASSES)
+    print(f'Original annotations        : {n_orig}')
+    print(f'Interpolated (CATR)         : {n_interp}')
+    print(f'Extrapolated (ML exact)     : {total_extrap_ml - total_extrap_ml_shifted}')
+    print(f'Extrapolated (ML shifted)   : {total_extrap_ml_shifted}')
+    print(f'Extrapolated (CV fallback)  : {total_extrap - total_extrap_ml}')
+    print(f'  of which non-ML classes   : {n_cv_non_ml} '
+          f'(traffic_cone/barrier/… — CV is correct here)')
+    print(f'  of which ML classes       : {n_cv_ml_cls} '
+          f'(temporal gap or no NPZ entry)')
+    print(f'    no NPZ entry (unrecoverable) : {cv_no_entry}')
+    print(f'    pos_tol exceeded (fixable?)  : {cv_dist_fail}')
+    print(f'    sanity filter (0.5s/speed)   : {cv_sanity_fail}')
+    if cv_model_counter:
+        mdl = sorted(cv_model_counter.items())
+        print(f'  fallback model breakdown  : ' +
+              '  '.join(f'{k}:{v}' for k, v in mdl))
+    if cv_by_class:
+        by_cls = sorted(cv_by_class.items(), key=lambda x: -x[1])
+        print(f'  CV breakdown by class     : ' +
+              '  '.join(f'{k}:{v}' for k, v in by_cls))
+    print(f'Total annotations           : {n_orig + n_interp + n_extrap}')
+
+    # ------------------------------------------------------------------
+    # Ego-distance filtering (two-level):
+    #   Level 2 — drop instances with no original obs within --max-dist.
+    #   Level 1 — tail-trim: drop all frames of an instance AFTER the last
+    #     frame where it appears within --max-dist. Using a tail-trim
+    #     rather than an independent per-frame check ensures trajectories
+    #     stay continuous; a per-frame check would punch holes in curved
+    #     tracks that briefly exit and re-enter the range boundary,
+    #     producing disconnected fragments 50+ m away.
+    # ------------------------------------------------------------------
+    if args.max_dist > 0:
+        # Level-2: instances that have at least one original obs within range.
+        valid_insts = set()
+        for info in infos:
+            boxes = info['gt_boxes']
+            is_i = info['is_interpolated']
+            is_e = info['is_extrapolated']
+            for bi, inst_ind in enumerate(info['instance_inds']):
+                if not is_i[bi] and not is_e[bi]:
+                    if np.hypot(boxes[bi, 0], boxes[bi, 1]) <= args.max_dist:
+                        valid_insts.add(inst_ind)
+        print(f'  Ego-distance filter: {len(valid_insts):,} instances have '
+              f'≥1 original obs within {args.max_dist} m.')
+
+        # Level-1 pre-pass: for each valid instance record the timestamp of
+        # its last frame that is within max_dist.  Timestamps are
+        # monotonically increasing within a scene, so all frames up to and
+        # including this timestamp are kept; frames after it are trimmed.
+        last_ts_within: dict = {}
+        for info in infos:
+            ts = info['timestamp']
+            boxes = info['gt_boxes']
+            for bi, inst_ind in enumerate(info['instance_inds']):
+                if inst_ind not in valid_insts:
+                    continue
+                if np.hypot(boxes[bi, 0], boxes[bi, 1]) <= args.max_dist:
+                    if ts > last_ts_within.get(inst_ind, -1):
+                        last_ts_within[inst_ind] = ts
+
+        # Level-1 apply: keep box if instance is valid AND the frame
+        # timestamp is at or before the last within-range timestamp.
+        n_before = sum(len(i['instance_inds']) for i in infos)
+        for info in infos:
+            ts = info['timestamp']
+            boxes = info['gt_boxes']
+            keep = np.array([
+                (inst_ind in valid_insts and
+                 ts <= last_ts_within.get(inst_ind, -1))
+                for bi, inst_ind in enumerate(info['instance_inds'])
+            ], dtype=bool)
+            info['gt_boxes'] = info['gt_boxes'][keep]
+            info['gt_velocity'] = info['gt_velocity'][keep]
+            info['gt_names'] = info['gt_names'][keep]
+            info['valid_flag'] = info['valid_flag'][keep]
+            info['num_lidar_pts'] = info['num_lidar_pts'][keep]
+            info['num_radar_pts'] = info['num_radar_pts'][keep]
+            info['is_interpolated'] = info['is_interpolated'][keep]
+            info['is_extrapolated'] = info['is_extrapolated'][keep]
+            info['gt_agent_fut_trajs'] = info['gt_agent_fut_trajs'][keep]
+            info['gt_agent_fut_masks'] = info['gt_agent_fut_masks'][keep]
+            info['instance_inds'] = [
+                inst for inst, k in zip(info['instance_inds'], keep) if k]
+        n_after = sum(len(i['instance_inds']) for i in infos)
+        print(f'  Removed {n_before - n_after:,} box entries '
+              f'({n_before:,} → {n_after:,}).')
+
+    print(f'Saving to {args.output} ...')
+    data['infos'] = infos
+    with open(args.output, 'wb') as f:
+        pickle.dump(data, f)
+    print('Done.')
+
+
+# ===========================================================================
+# Visualize sub-command
+# ===========================================================================
+
+# Color palette for the BEV plots — one fixed hex color per box provenance.
+_C_OBS = '#1565c0'  # observed, valid_flag=True (blue)
+_C_INVALID = '#6a1b9a'  # observed, valid_flag=False (purple)
+_C_INTERP = '#e65100'  # interpolated (CATR) (deep orange)
+_C_EXTRAP = '#b71c1c'  # extrapolated (dark red)
+_C_EGO = '#2e7d32'  # ego vehicle (dark green)
+_C_BG = '#ffffff'  # plot background (white)
+
+
+def _ego_pose_global(info):
+    """Return the ego pose as ``(xy, yaw)`` in the global frame.
+
+    Reads the sample's ``ego2global_translation`` / ``ego2global_rotation``
+    record; ``xy`` is the (x, y) position and ``yaw`` is the z-axis rotation
+    extracted from the quaternion components.
+    """
+    from pyquaternion import Quaternion
+    q = Quaternion(info['ego2global_rotation'])
+    pos = np.array(info['ego2global_translation'])[:2]
+    # Standard quaternion → yaw (z-rotation) formula.
+    yaw = np.arctan2(2*(q.w*q.z + q.x*q.y), 1 - 2*(q.y**2 + q.z**2))
+    return pos, yaw
+
+
+def _draw_box(ax, cx, cy, length, width, yaw, color,
+              alpha_face=0.22, alpha_edge=0.85, lw=0.8, zorder=3):
+    """Draw one oriented BEV rectangle centred at ``(cx, cy)``.
+
+    Two stacked Polygon patches are used: a translucent fill (alpha_face)
+    and a crisper outline (alpha_edge, drawn one zorder above) so the
+    outline stays visible when boxes overlap.
+    """
+    import matplotlib.pyplot as plt
+    # Rotate the four half-extent corners by yaw, then translate to centre.
+    c, s = np.cos(yaw), np.sin(yaw)
+    R2 = np.array([[c, -s], [s, c]])
+    half = np.array([[ length/2,  width/2],
+                     [-length/2,  width/2],
+                     [-length/2, -width/2],
+                     [ length/2, -width/2]])
+    corners = half @ R2.T + np.array([cx, cy])
+    ax.add_patch(plt.Polygon(corners, closed=True, facecolor=color, edgecolor=color,
+                             alpha=alpha_face, linewidth=lw, zorder=zorder))
+    ax.add_patch(plt.Polygon(corners, closed=True, facecolor='none', edgecolor=color,
+                             alpha=alpha_edge, linewidth=lw, zorder=zorder+1))
+
+
+def _visualize_scene(infos, gidxs, ax, title='', show_forecast=False, nusc=None):
+    """Render one scene as a BEV plot in the first-ego coordinate frame.
+
+    Parameters
+    ----------
+    infos : list of per-sample info dicts (read-only here).
+    gidxs : indices into *infos* for this scene; the first entry defines
+        the reference ego pose (origin / +X heading).
+    ax : matplotlib Axes to draw on.
+    title : text prefix for the axes title.
+    show_forecast : when True, also draw each box's future-trajectory
+        waypoints as dotted lines + dots.
+    nusc : optional NuScenes handle; when given, the ego-centric map is
+        rendered as the background (failures are silently ignored).
+    """
+    import matplotlib.patches as mpatches
+    import matplotlib.lines as mlines
+
+    inst_data = defaultdict(list)
+    ego_pos, ego_yaws = [], []
+
+    # Reference frame: first ego pose in the scene.
+    # All global-frame coordinates are transformed so that the first ego
+    # position is the origin and the first ego heading points along +X.
+    p0, yaw0 = _ego_pose_global(infos[gidxs[0]])
+    c0, s0 = np.cos(yaw0), np.sin(yaw0)
+    R_inv = np.array([[c0, s0], [-s0, c0]])  # global → first-ego-frame rotation
+
+    def to_ego(xy_global):
+        """Transform (N,2) or (2,) from global frame to first-ego frame."""
+        return (np.asarray(xy_global) - p0) @ R_inv.T
+
+    # Map background — drawn first so trajectories render on top.
+    # nusc.explorer.render_ego_centric_map uses flat vehicle coordinates for
+    # the reference sample, which matches our first-ego-frame (+X forward,
+    # +Y left) exactly.  We pass a generous axes_limit so the full scene
+    # extent is covered; explicit axis limits are set later from data.
+    if nusc is not None:
+        try:
+            first_token = infos[gidxs[0]]['token']
+            lidar_token = nusc.get('sample', first_token)['data']['LIDAR_TOP']
+            nusc.explorer.render_ego_centric_map(
+                sample_data_token=lidar_token, axes_limit=200, ax=ax)
+        except Exception:
+            pass  # map unavailable — continue without it
+
+    for frame_idx, gi in enumerate(gidxs):
+        info = infos[gi]
+        R, t = _lidar2global_RT(info)
+        boxes = info['gt_boxes']
+        is_i = info['is_interpolated']
+        is_e = info['is_extrapolated']
+
+        if len(boxes):
+            # lidar → global → first-ego frame for all boxes of this sample
+            xy_g = (boxes[:, :3] @ R.T + t)[:, :2]
+            yaw_g = boxes[:, 6] + np.arctan2(R[1, 0], R[0, 0])
+            xy_e = to_ego(xy_g)
+            yaw_e = yaw_g - yaw0
+            valid = info['valid_flag']
+            for bi in range(len(boxes)):
+                # Color encodes provenance: interpolated > extrapolated >
+                # observed-valid > observed-invalid (first match wins).
+                if is_i[bi]:
+                    color = _C_INTERP
+                elif is_e[bi]:
+                    color = _C_EXTRAP
+                elif valid[bi]:
+                    color = _C_OBS
+                else:
+                    color = _C_INVALID
+
+                # Reconstruct future waypoints in first-ego frame.
+                # Deltas are in the current sample's lidar frame; accumulate
+                # them to get absolute lidar positions, then transform to global
+                # and finally to first-ego frame.
+                fut_traj_raw = info['gt_agent_fut_trajs'][bi]   # (T, 2) lidar deltas
+                fut_mask_raw = info['gt_agent_fut_masks'][bi]   # (T,)
+                pos_l = np.array([boxes[bi, 0], boxes[bi, 1]], dtype=np.float64)
+                fut_pts_l = []
+                for k in range(len(fut_traj_raw)):
+                    if fut_mask_raw[k] < 0.5:
+                        break
+                    pos_l = pos_l + fut_traj_raw[k]
+                    fut_pts_l.append(pos_l.copy())
+                if fut_pts_l:
+                    fpl = np.zeros((len(fut_pts_l), 3))
+                    fpl[:, :2] = np.array(fut_pts_l)
+                    fut_pts_ego = to_ego((fpl @ R.T + t)[:, :2])
+                else:
+                    fut_pts_ego = np.zeros((0, 2))
+
+                inst_data[info['instance_inds'][bi]].append((
+                    frame_idx,
+                    float(xy_e[bi, 0]), float(xy_e[bi, 1]),
+                    float(boxes[bi, 3]), float(boxes[bi, 4]),
+                    float(yaw_e[bi]), color,
+                    fut_pts_ego,   # index 7: (K,2) future waypoints in ego frame
+                ))
+
+        pos, yaw = _ego_pose_global(info)
+        ego_pos.append(to_ego(pos))
+        ego_yaws.append(yaw - yaw0)
+
+    ego_pos = np.array(ego_pos)
+
+    # Track lines — drawn for every consecutive pair in chronological order.
+    # Segments bridging a frame gap (diff > 1) are dashed to indicate the
+    # discontinuity; all other segments are solid.
+    for frames in inst_data.values():
+        frames = sorted(frames, key=lambda f: f[0])
+        for k in range(len(frames) - 1):
+            f0, f1 = frames[k], frames[k + 1]
+            gap = f1[0] - f0[0]
+            ax.plot([f0[1], f1[1]], [f0[2], f1[2]],
+                    color=f0[6], linewidth=0.7,
+                    alpha=0.35 if gap > 1 else 0.55,
+                    linestyle='--' if gap > 1 else '-',
+                    zorder=2, solid_capstyle='round')
+
+    # Boxes — draw back-to-front so valid observed renders on top
+    _alpha_face = {_C_EXTRAP: 0.13, _C_INTERP: 0.22, _C_INVALID: 0.18, _C_OBS: 0.22}
+    _alpha_edge = {_C_EXTRAP: 0.55, _C_INTERP: 0.85, _C_INVALID: 0.70, _C_OBS: 0.85}
+    _lw = {_C_EXTRAP: 0.5, _C_INTERP: 0.8, _C_INVALID: 0.7, _C_OBS: 0.8}
+    for target_color in [_C_EXTRAP, _C_INTERP, _C_INVALID, _C_OBS]:
+        for frames in inst_data.values():
+            for f in frames:
+                if f[6] != target_color:
+                    continue
+                _draw_box(ax, f[1], f[2], f[3], f[4], f[5], color=f[6],
+                          alpha_face=_alpha_face[f[6]],
+                          alpha_edge=_alpha_edge[f[6]],
+                          lw=_lw[f[6]])
+
+    # Forecast trajectories (one dotted line per box appearance with valid future steps)
+    if show_forecast:
+        for frames in inst_data.values():
+            for f in frames:
+                fut_pts = f[7]
+                if len(fut_pts) == 0:
+                    continue
+                all_pts = np.vstack([[[f[1], f[2]]], fut_pts])
+                ax.plot(all_pts[:, 0], all_pts[:, 1],
+                        color=f[6], alpha=0.55, linewidth=1.0,
+                        linestyle=':', zorder=4, solid_capstyle='round')
+                ax.scatter(fut_pts[:, 0], fut_pts[:, 1],
+                           c=f[6], s=5, alpha=0.65, zorder=5, linewidths=0)
+
+    # Ego trajectory and boxes
+    ax.plot(ego_pos[:, 0], ego_pos[:, 1],
+            color='#333333', linewidth=1.5, linestyle='--', zorder=6, alpha=0.8)
+    ax.scatter(*ego_pos[0], color='#333333', s=30, zorder=8, marker='o')
+    ax.scatter(*ego_pos[-1], color='#333333', s=30, zorder=8, marker='x')
+    for pos, yaw in zip(ego_pos, ego_yaws):
+        _draw_box(ax, pos[0], pos[1], 4.08, 1.73, yaw,
+                  color=_C_EGO, alpha_face=0.30, alpha_edge=0.9, lw=1.0, zorder=7)
+
+    # Explicit axis limits from data so the map imshow (which calls set_xlim/
+    # set_ylim internally) does not dictate the final view.
+    all_x = ([f[1] for v in inst_data.values() for f in v]
+             + list(ego_pos[:, 0]))
+    all_y = ([f[2] for v in inst_data.values() for f in v]
+             + list(ego_pos[:, 1]))
+    if all_x:
+        xspan = max(all_x) - min(all_x)
+        yspan = max(all_y) - min(all_y)
+        span = max(xspan, yspan, 20.0)
+        pad = span * 0.08
+        cx = (max(all_x) + min(all_x)) / 2
+        cy = (max(all_y) + min(all_y)) / 2
+        ax.set_xlim(cx - span / 2 - pad, cx + span / 2 + pad)
+        ax.set_ylim(cy - span / 2 - pad, cy + span / 2 + pad)
+    ax.set_aspect('equal')
+    ax.set_facecolor(_C_BG)
+    ax.grid(True, color='#bbbbbb', alpha=0.4, linewidth=0.5)
+    ax.tick_params(colors='#444444', labelsize=7)
+    for spine in ax.spines.values():
+        spine.set_color('#cccccc')
+    ax.set_xlabel('X (m)', color='#444444', fontsize=8)
+    ax.set_ylabel('Y (m)', color='#444444', fontsize=8)
+    n_valid = sum(sum(1 for f in v if f[6] == _C_OBS) for v in inst_data.values())
+    n_invalid = sum(sum(1 for f in v if f[6] == _C_INVALID) for v in inst_data.values())
+    n_interp = sum(sum(1 for f in v if f[6] == _C_INTERP) for v in inst_data.values())
+    n_extrap = sum(sum(1 for f in v if f[6] == _C_EXTRAP) for v in inst_data.values())
+    ax.set_title(
+        f'{title} | Visible {n_valid} Occluded {n_invalid} '
+        f'Interpolated {n_interp} Extrapolated {n_extrap}',
+        color='#111111', fontsize=8, pad=4)
+
+    leg_handles = [
+        mpatches.Patch(color='#aaaaaa', label='Driveable Area'),
+        mpatches.Patch(color=_C_OBS, label='Visible'),
+        mpatches.Patch(color=_C_INVALID, label='Occluded'),
+        mpatches.Patch(color=_C_INTERP, label='Interpolated'),
+        mpatches.Patch(color=_C_EXTRAP, label='Extrapolated'),
+        mpatches.Patch(color=_C_EGO, label='Ego'),
+    ]
+    if show_forecast:
+        leg_handles.append(mlines.Line2D(
+            [], [], color='#333333', linestyle=':', linewidth=1.2, alpha=0.7,
+            label='Forecast'))
+    ax.legend(handles=leg_handles, loc='upper right', framealpha=0.85,
+              fontsize=6, labelcolor='#111111', facecolor='white',
+              edgecolor='#cccccc', ncol=1)
+
+
+# ===========================================================================
+# Single-track helpers (individual track plots)
+# ===========================================================================
+
def _collect_all_tracks(infos, all_scenes):
    """Return a list of per-instance track dicts across all scenes.

    Uses the same first-ego-frame coordinate system as ``_visualize_scene``:
    each scene is centred on its first ego pose so coordinates are directly
    comparable to the scene-level BEV plots.

    Each frame tuple is ``(frame_idx, x_e, y_e, l, w, yaw_e, color)`` —
    identical to the entries stored in ``inst_data`` inside ``_visualize_scene``.
    ``cv_frames`` is a parallel list of ``(x_e, y_e)`` CV predictions for each
    extrapolated frame, used by the mismatch plot.

    Parameters
    ----------
    infos : list of per-sample info dicts. Keys read here: ``scene_token``,
        ``token``, ``timestamp``, ``gt_boxes``, ``gt_velocity``, ``gt_names``,
        ``instance_inds``, ``is_interpolated``, ``is_extrapolated``,
        ``valid_flag``.
    all_scenes : list of per-scene lists of global indices into *infos*,
        each sorted by timestamp.
    """
    tracks = []
    for gidxs in all_scenes:
        sc_tok = infos[gidxs[0]]['scene_token']

        # Same reference frame as _visualize_scene: first ego pose of the scene.
        p0, yaw0 = _ego_pose_global(infos[gidxs[0]])
        c0, s0 = np.cos(yaw0), np.sin(yaw0)
        R_ego = np.array([[c0, s0], [-s0, c0]])  # global → first-ego rotation

        def to_ego(xy):
            # Rigid transform of (..., 2) global points into the first-ego frame.
            return (np.asarray(xy) - p0) @ R_ego.T

        # inst_frames: same format as inst_data in _visualize_scene.
        # inst_meta: raw (gi, bi) lists for CV anchor computation.
        inst_frames = defaultdict(list)
        inst_meta = defaultdict(lambda: {'obs_raw': [], 'extrap_raw': []})

        for frame_idx, gi in enumerate(gidxs):
            info = infos[gi]
            R, t = _lidar2global_RT(info)
            boxes = info['gt_boxes']
            if not len(boxes):
                continue
            xy_g = (boxes[:, :3] @ R.T + t)[:, :2]  # lidar → global
            yaw_g = boxes[:, 6] + np.arctan2(R[1, 0], R[0, 0])
            xy_e = to_ego(xy_g)  # global → first-ego
            yaw_e = yaw_g - yaw0
            is_i = info['is_interpolated']
            is_e = info['is_extrapolated']
            valid = info['valid_flag']

            for bi, inst_ind in enumerate(info['instance_inds']):
                # Colour encodes provenance; note that occluded (invalid)
                # observed frames still count as "obs_raw" for CV anchoring.
                if is_i[bi]:
                    color = _C_INTERP
                elif is_e[bi]:
                    color = _C_EXTRAP
                    inst_meta[inst_ind]['extrap_raw'].append((gi, bi))
                elif valid[bi]:
                    color = _C_OBS
                    inst_meta[inst_ind]['obs_raw'].append((gi, bi))
                else:
                    color = _C_INVALID
                    inst_meta[inst_ind]['obs_raw'].append((gi, bi))

                inst_frames[inst_ind].append((
                    frame_idx,
                    float(xy_e[bi, 0]), float(xy_e[bi, 1]),
                    float(boxes[bi, 3]), float(boxes[bi, 4]),
                    float(yaw_e[bi]), color,
                ))

        for inst_ind, raw_frames in inst_frames.items():
            frames_sorted = sorted(raw_frames, key=lambda f: f[0])
            n_interp = sum(1 for f in frames_sorted if f[6] == _C_INTERP)
            n_extrap = sum(1 for f in frames_sorted if f[6] == _C_EXTRAP)

            def _path_len(color):
                # Total polyline length over the frames of the given colour.
                pts = [(f[1], f[2]) for f in frames_sorted if f[6] == color]
                return sum(
                    np.hypot(pts[i+1][0] - pts[i][0], pts[i+1][1] - pts[i][1])
                    for i in range(len(pts) - 1)
                )
            extrap_dist = _path_len(_C_EXTRAP)
            interp_dist = _path_len(_C_INTERP)

            # CV mismatch: project last-observed velocity forward in ego frame.
            # Rigid transform preserves distances, so scores are identical to
            # computing in global frame.
            max_cv_ml_diff = 0.0
            cv_frames = []  # (x_e, y_e) per extrapolated frame
            meta = inst_meta[inst_ind]
            if meta['obs_raw'] and meta['extrap_raw']:
                # Use the last observed frame BEFORE the extrapolation starts.
                # obs_raw may contain re-appearance frames that come after the
                # extrap gap; using those as the anchor gives negative dt and
                # a completely wrong CV projection.
                first_extrap_gi = meta['extrap_raw'][0][0]
                obs_before = [(g, b) for g, b in meta['obs_raw']
                              if g < first_extrap_gi]
                if not obs_before:
                    obs_before = meta['obs_raw']
                anc_gi, anc_bi = obs_before[-1]
                anc_info = infos[anc_gi]
                state_g = _box_to_global(
                    anc_info['gt_boxes'][anc_bi],
                    anc_info['gt_velocity'][anc_bi],
                    anc_info)
                xy0_e = to_ego(np.array([state_g[0], state_g[1]]))
                x0_e, y0_e = float(xy0_e[0]), float(xy0_e[1])
                v_e = R_ego @ np.array([state_g[7], state_g[8]])
                vx_e, vy_e = float(v_e[0]), float(v_e[1])
                # Timestamps appear to be microseconds — *1e-6 → seconds.
                t0 = anc_info['timestamp'] * 1e-6

                for gi_e, bi_e in meta['extrap_raw']:
                    dt = infos[gi_e]['timestamp'] * 1e-6 - t0
                    cv_frames.append((x0_e + vx_e * dt, y0_e + vy_e * dt))

                # Only score when each extrapolated frame has a CV partner.
                ext_e = [(f[1], f[2]) for f in frames_sorted if f[6] == _C_EXTRAP]
                if cv_frames and len(cv_frames) == len(ext_e):
                    diffs = [np.hypot(e[0] - c[0], e[1] - c[1])
                             for e, c in zip(ext_e, cv_frames)]
                    max_cv_ml_diff = float(max(diffs))

            # Class name from last observed frame (or first extrap/interp).
            class_name = ''
            for src, idx in [(meta['obs_raw'], -1), (meta['extrap_raw'], 0)]:
                if src:
                    gi_c, bi_c = src[idx]
                    class_name = infos[gi_c]['gt_names'][bi_c]
                    break

            tracks.append({
                'inst_ind': inst_ind,
                'scene_token': sc_tok,
                'first_sample_token': infos[gidxs[0]]['token'],
                'class_name': class_name,
                'frames': frames_sorted,  # (frame_idx, x_e, y_e, l, w, yaw_e, color)
                'cv_frames': cv_frames,   # (x_e, y_e) per extrap frame
                'n_interp': n_interp,
                'n_extrap': n_extrap,
                'extrap_dist': extrap_dist,
                'interp_dist': interp_dist,
                'max_cv_ml_diff': max_cv_ml_diff,
            })
    return tracks
+
+
def _visualize_single_track_ax(track, ax, show_cv=False, nusc=None):
    """Render one instance track using the same style as ``_visualize_scene``.

    Draws track lines between adjacent frames and oriented boxes at every
    frame position. The view is auto-zoomed to the track's spatial extent.
    Coordinates are already in first-ego frame (set by ``_collect_all_tracks``).

    If *nusc* is provided, a map background is rendered first using the scene's
    first-frame LIDAR_TOP token. The auto-zoom axis limits are always applied
    last so the view is unchanged.

    Parameters
    ----------
    track : dict produced by ``_collect_all_tracks``.
    ax : matplotlib Axes to draw into.
    show_cv : also draw the constant-velocity baseline trajectory.
    nusc : optional NuScenes handle for the map background.
    """
    import matplotlib.patches as mpatches
    import matplotlib.lines as mlines

    frames = track['frames']  # (frame_idx, x_e, y_e, l, w, yaw_e, color)
    cv_frames = track['cv_frames']

    if not frames:
        return

    # --- Compute auto-zoom bounds first (needed for map axes_limit) ----------
    all_xy = [(f[1], f[2]) for f in frames]
    if show_cv:
        all_xy.extend(cv_frames)
    xs = [p[0] for p in all_xy]
    ys = [p[1] for p in all_xy]
    # Square view, at least 8 m wide, with 20 % padding on each side.
    span = max(max(xs) - min(xs), max(ys) - min(ys), 8.0)
    pad = span * 0.20
    cx = (max(xs) + min(xs)) / 2
    cy = (max(ys) + min(ys)) / 2
    xlim = (cx - span / 2 - pad, cx + span / 2 + pad)
    ylim = (cy - span / 2 - pad, cy + span / 2 + pad)

    # --- Map background (rendered before track elements) ---------------------
    ax.set_facecolor(_C_BG)
    if nusc is not None:
        try:
            first_token = track.get('first_sample_token')
            if first_token:
                lidar_token = nusc.get('sample', first_token)['data']['LIDAR_TOP']
                # axes_limit must reach the farthest corner of the track from
                # the scene origin (0, 0) so the full map tile is loaded.
                map_limit = max(
                    abs(cx) + span / 2 + pad,
                    abs(cy) + span / 2 + pad,
                ) + 20.0
                nusc.explorer.render_ego_centric_map(
                    sample_data_token=lidar_token,
                    axes_limit=map_limit,
                    ax=ax)
        except Exception:
            pass  # map unavailable — continue without it

    # --- Track lines (all consecutive pairs; dashed across frame gaps) -------
    for k in range(len(frames) - 1):
        f0, f1 = frames[k], frames[k + 1]
        gap = f1[0] - f0[0]
        ax.plot([f0[1], f1[1]], [f0[2], f1[2]],
                color=f0[6], lw=0.7,
                alpha=0.35 if gap > 1 else 0.55,
                linestyle='--' if gap > 1 else '-',
                zorder=2, solid_capstyle='round')

    # --- Boxes back-to-front (matching _visualize_scene render order) --------
    # Per-colour styling: extrapolated boxes are the faintest, observed the
    # strongest; rendering order puts observed boxes on top.
    _alpha_face = {_C_EXTRAP: 0.13, _C_INTERP: 0.22, _C_INVALID: 0.18, _C_OBS: 0.22}
    _alpha_edge = {_C_EXTRAP: 0.55, _C_INTERP: 0.85, _C_INVALID: 0.70, _C_OBS: 0.85}
    _lw_d = {_C_EXTRAP: 0.5, _C_INTERP: 0.8, _C_INVALID: 0.7, _C_OBS: 0.8}
    for target_color in [_C_EXTRAP, _C_INTERP, _C_INVALID, _C_OBS]:
        for f in frames:
            if f[6] != target_color:
                continue
            _draw_box(ax, f[1], f[2], f[3], f[4], f[5], color=f[6],
                      alpha_face=_alpha_face[f[6]],
                      alpha_edge=_alpha_edge[f[6]],
                      lw=_lw_d[f[6]])

    # --- CV baseline ---------------------------------------------------------
    if show_cv and len(cv_frames) >= 2:
        cv_xs = [p[0] for p in cv_frames]
        cv_ys = [p[1] for p in cv_frames]
        ax.plot(cv_xs, cv_ys, color='#ffd54f', lw=1.0, ls='--',
                marker='.', ms=2, alpha=0.75, zorder=3)

    # Star at the obs→extrap/interp handover point.
    obs_frames = [f for f in frames if f[6] in (_C_OBS, _C_INVALID)]
    if obs_frames:
        lf = obs_frames[-1]
        ax.scatter(lf[1], lf[2], color='#333333', s=22, zorder=5,
                   marker='*', linewidths=0)

    # --- Apply auto-zoom limits (identical whether map was rendered or not) --
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.set_aspect('equal')
    ax.grid(True, color='#bbbbbb', alpha=0.4, lw=0.5)
    ax.tick_params(colors='#444444', labelsize=6)
    for spine in ax.spines.values():
        spine.set_color('#cccccc')

    leg_handles = [
        mpatches.Patch(color=_C_OBS, label='Visible'),
        mpatches.Patch(color=_C_INVALID, label='Occluded'),
        mpatches.Patch(color=_C_INTERP, label='Interpolated'),
        mpatches.Patch(color=_C_EXTRAP, label='Extrapolated'),
    ]
    if show_cv:
        leg_handles.append(mlines.Line2D(
            [], [], color='#ffd54f', ls='--', lw=1.2, label='CV baseline'))
    ax.legend(handles=leg_handles, loc='upper right', framealpha=0.85,
              fontsize=5.5, labelcolor='#111111', facecolor='white',
              edgecolor='#cccccc', ncol=1)

    # Compact title: class, scene-token prefix, augmentation stats.
    cls = track['class_name']
    tok = track['scene_token'][:6]
    n_e = track['n_extrap']
    n_i = track['n_interp']
    d_e = track.get('extrap_dist', 0.0)
    d_i = track.get('interp_dist', 0.0)
    diff = track['max_cv_ml_diff']
    parts = [f'{cls}', f'{tok}…']
    if n_i > 0:
        parts.append(f'{n_i} interp ({d_i:.1f} m)')
    if n_e > 0:
        parts.append(f'{n_e} extrap ({d_e:.1f} m)')
    if show_cv:
        parts.append(f'max CV Δ={diff:.1f} m')
    ax.set_title(' | '.join(parts), color='#111111', fontsize=7, pad=3)
+
+
def _save_single_track_grid(tracks, out_path, suptitle, score_fn,
                            top_n=12, show_cv=False, nusc=None):
    """Save a 4-column grid of single-track BEV plots ranked by *score_fn*.

    The *top_n* highest-scoring tracks are drawn row by row; grid cells
    beyond the last track are hidden.
    """
    import matplotlib.pyplot as plt

    ranked = sorted(tracks, key=score_fn, reverse=True)[:top_n]
    if not ranked:
        print(f'No tracks for {out_path}, skipping.')
        return

    n_cols = 4
    n_rows = -(-len(ranked) // n_cols)  # ceiling division
    fig, axis_grid = plt.subplots(n_rows, n_cols,
                                  figsize=(5 * n_cols, 5 * n_rows),
                                  facecolor=_C_BG)
    flat_axes = np.array(axis_grid).flatten()

    for axis, track in zip(flat_axes, ranked):
        _visualize_single_track_ax(track, axis, show_cv=show_cv, nusc=nusc)
    for unused_axis in flat_axes[len(ranked):]:
        unused_axis.set_visible(False)

    fig.suptitle(suptitle, color='#111111', fontsize=11, y=1.01)
    fig.subplots_adjust(hspace=0.5, wspace=0.3)
    plt.savefig(out_path, dpi=200, bbox_inches='tight',
                facecolor=fig.get_facecolor())
    print(f'Saved → {out_path}')
    plt.close(fig)
+
+
+# ===========================================================================
+# Ego-distance histogram helper
+# ===========================================================================
+
+def _count_missing_under_dist(infos, all_scenes, max_dist):
+ """Count instances whose last annotation ends before the scene and within max_dist.
+
+ For each instance in each scene, finds the last frame where it appears
+ (observed or extrapolated). If that frame is not the final scene frame AND
+ the ego distance at that position is below *max_dist*, the instance is
+ counted as still-missing — our augmentation did not reach the scene end.
+
+ Returns
+ -------
+ (n_instances, n_frames) :
+ n_instances — number of such instance-scene pairs.
+ n_frames — total unannotated frame-slots those instances leave behind.
+ """
+ n_instances = 0
+ n_gaps = 0
+ for gidxs in all_scenes:
+ if len(gidxs) <= 1:
+ continue
+ inst_last = {} # inst_ind → (frame_pos, gi, bi) of last annotation
+ for frame_pos, gi in enumerate(gidxs):
+ info = infos[gi]
+ for bi, inst_ind in enumerate(info['instance_inds']):
+ inst_last[inst_ind] = (frame_pos, gi, bi)
+ for inst_ind, (fp, gi, bi) in inst_last.items():
+ if fp >= len(gidxs) - 1:
+ continue # reaches the scene end — not missing
+ box = infos[gi]['gt_boxes'][bi]
+ if float(np.hypot(box[0], box[1])) < max_dist:
+ n_instances += 1
+ n_gaps += len(gidxs) - 1 - fp
+ return n_instances, n_gaps
+
+
def _save_ego_dist_hist(infos, out_path, max_dist=60, n_missing=None):
    """Overlay histograms of ego-distance for original vs augmented annotations.

    Boxes are stored in the lidar frame whose origin is the ego vehicle, so
    ego distance = hypot(box_x, box_y) directly.

    * 'Original' — boxes where both is_interpolated and is_extrapolated are False.
    * 'Augmented' — all boxes (original + interpolated + extrapolated).

    Parameters
    ----------
    max_dist : upper x-axis limit in metres (bins span [0, max_dist]).
    n_missing : optional (n_instances, n_frames) from ``_count_missing_under_dist``
        — when provided, printed to stdout alongside the saved plot.
    """
    import matplotlib.pyplot as plt

    # Gather per-box ego distances, split by provenance.
    dists_original = []
    dists_augmented = []
    for info in infos:
        boxes = info['gt_boxes']
        interp_flags = info['is_interpolated']
        extrap_flags = info['is_extrapolated']
        for bi in range(len(info['instance_inds'])):
            dist = float(np.hypot(boxes[bi, 0], boxes[bi, 1]))
            dists_augmented.append(dist)
            if not (interp_flags[bi] or extrap_flags[bi]):
                dists_original.append(dist)

    dists_original = np.asarray(dists_original)
    dists_augmented = np.asarray(dists_augmented)

    edges = np.linspace(0, max_dist, 51)

    fig, ax = plt.subplots(figsize=(10, 5), facecolor=_C_BG)
    ax.set_facecolor(_C_BG)
    in_range_orig = dists_original[dists_original <= max_dist]
    in_range_aug = dists_augmented[dists_augmented <= max_dist]
    ax.hist(in_range_orig, bins=edges, color=_C_OBS, alpha=0.75,
            label=f'Original (n={len(in_range_orig):,})')
    ax.hist(in_range_aug, bins=edges, color=_C_EXTRAP, alpha=0.55,
            label=f'Augmented (n={len(in_range_aug):,})')

    ax.set_xlabel('Ego distance (m)', color='#444444', fontsize=10)
    ax.set_ylabel('Count', color='#444444', fontsize=10)
    ax.set_title(
        f'Ego-distance distribution (0–{max_dist} m) — original vs augmented',
        color='#111111', fontsize=11, pad=6)
    ax.tick_params(colors='#444444')
    ax.legend(labelcolor='#111111', facecolor='white', edgecolor='#cccccc',
              framealpha=0.9, fontsize=10)
    for spine in ax.spines.values():
        spine.set_color('#cccccc')
    ax.grid(True, color='#bbbbbb', alpha=0.4, linewidth=0.5)

    if n_missing is not None:
        n_inst, n_frm = n_missing
        print(f'Still missing (< {max_dist} m): '
              f'{n_inst:,} instances, {n_frm:,} frame-slots')

    plt.tight_layout()
    plt.savefig(out_path, dpi=200, bbox_inches='tight',
                facecolor=fig.get_facecolor())
    print(f'Saved → {out_path}')
    plt.close(fig)
+
+
+# ===========================================================================
+# Visualize sub-command
+# ===========================================================================
+
def cmd_visualize(args):
    """Entry point for the ``visualize`` sub-command.

    Loads the (augmented) info pkl, optionally opens NuScenes for drivable-area
    map backgrounds, and then either renders a single scene (``--scene-idx``)
    or a battery of summary figures: top-N scene grids ranked by interpolated /
    extrapolated / invalid counts, three single-track grids, and an
    ego-distance histogram. All outputs go to ``args.output_dir``.
    """
    import matplotlib.pyplot as plt

    print(f'Loading {args.pkl} ...')
    with open(args.pkl, 'rb') as f:
        data = pickle.load(f)
    infos = data['infos']

    # NuScenes is only needed for map rendering — failure is non-fatal.
    nusc = None
    if getattr(args, 'nuscenes_dataroot', None):
        try:
            from nuscenes import NuScenes
            print(f'Loading NuScenes from {args.nuscenes_dataroot} '
                  f'(version={args.nuscenes_version}) for map rendering …')
            nusc = NuScenes(version=args.nuscenes_version,
                            dataroot=args.nuscenes_dataroot, verbose=False)
            print(' NuScenes loaded.')
        except Exception as e:
            print(f' WARNING: could not load NuScenes ({e}); map will be skipped.')

    # Group frame indices by scene and order each scene by timestamp.
    scene_map = defaultdict(list)
    for gi, info in enumerate(infos):
        scene_map[info['scene_token']].append(gi)
    for sc in scene_map:
        scene_map[sc].sort(key=lambda i: infos[i]['timestamp'])
    all_scenes = list(scene_map.values())

    os.makedirs(args.output_dir, exist_ok=True)

    # Single-scene mode: render one scene and return early.
    if args.scene_idx is not None:
        scene_list = [all_scenes[args.scene_idx]]
        fig, ax = plt.subplots(figsize=(10, 10), facecolor=_C_BG)
        gidxs = scene_list[0]
        sc_tok = infos[gidxs[0]]['scene_token']
        _visualize_scene(infos, gidxs, ax, title=f'Scene {sc_tok[:8]}…', nusc=nusc)
        out = args.output or os.path.join(args.output_dir, 'occ_viz.png')
        plt.tight_layout()
        plt.savefig(out, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())
        print(f'Saved → {out}')
        plt.close(fig)
        return

    # Scene-ranking scores used to pick the most interesting scenes.
    def _score_interp(gidxs):
        return sum(int(infos[gi]['is_interpolated'].sum()) for gi in gidxs)

    def _score_extrap(gidxs):
        return sum(int(infos[gi]['is_extrapolated'].sum()) for gi in gidxs)

    def _score_invalid(gidxs):
        # Observed-but-invalid boxes (neither interpolated nor extrapolated).
        total = 0
        for gi in gidxs:
            info = infos[gi]
            obs_mask = ~info['is_interpolated'] & ~info['is_extrapolated']
            total += int((obs_mask & ~info['valid_flag']).sum())
        return total

    top_interp = sorted(all_scenes, key=_score_interp, reverse=True)[:args.num_scenes]
    top_extrap = sorted(all_scenes, key=_score_extrap, reverse=True)[:args.num_scenes]
    top_invalid = sorted(all_scenes, key=_score_invalid, reverse=True)[:args.num_scenes]

    def _save_grid(scene_list, out_path, suptitle, show_forecast=False):
        # Render up to 3 scenes per row; hide any unused axes.
        ncols = min(3, len(scene_list))
        nrows = (len(scene_list) + ncols - 1) // ncols
        fig, axes = plt.subplots(nrows, ncols, figsize=(7*ncols, 7*nrows), facecolor=_C_BG)
        axes = np.array(axes).flatten()
        for i, gidxs in enumerate(scene_list):
            sc_tok = infos[gidxs[0]]['scene_token']
            _visualize_scene(infos, gidxs, axes[i],
                             title=f'Scene {sc_tok[:8]}…',
                             show_forecast=show_forecast, nusc=nusc)
        for j in range(len(scene_list), len(axes)):
            axes[j].set_visible(False)
        fig.suptitle(suptitle, color='#111111', fontsize=11, y=1.01)
        fig.subplots_adjust(hspace=0.35, wspace=0.25)
        plt.savefig(out_path, dpi=300, bbox_inches='tight', facecolor=fig.get_facecolor())
        print(f'Saved → {out_path}')
        plt.close(fig)

    _save_grid(top_interp,
               os.path.join(args.output_dir, 'occ_viz_interp.png'),
               f'Top {args.num_scenes} scenes by interpolated annotation count')
    _save_grid(top_extrap,
               os.path.join(args.output_dir, 'occ_viz_extrap.png'),
               f'Top {args.num_scenes} scenes by extrapolated annotation count')
    _save_grid(top_invalid,
               os.path.join(args.output_dir, 'occ_viz_invalid.png'),
               f'Top {args.num_scenes} scenes by invalid (0-pt) observed annotation count')
    # ------------------------------------------------------------------
    # Single-track grids
    # ------------------------------------------------------------------
    print('Collecting per-instance track data …')
    all_tracks = _collect_all_tracks(infos, all_scenes)
    extrap_tracks = [t for t in all_tracks if t['n_extrap'] > 0]
    interp_tracks = [t for t in all_tracks if t['n_interp'] > 0]

    _save_single_track_grid(
        extrap_tracks,
        os.path.join(args.output_dir, 'occ_viz_extrap_single.png'),
        'Top 12 extrapolated tracks — most distance covered during extrapolation',
        score_fn=lambda t: t['extrap_dist'],
        nusc=nusc,
    )
    _save_single_track_grid(
        interp_tracks,
        os.path.join(args.output_dir, 'occ_viz_interp_single.png'),
        'Top 12 interpolated tracks — most distance covered during interpolation',
        score_fn=lambda t: t['interp_dist'],
        nusc=nusc,
    )
    _save_single_track_grid(
        extrap_tracks,
        os.path.join(args.output_dir, 'occ_viz_extrap_single_mismatch.png'),
        'Top 12 extrapolated tracks — largest extrapolation-vs-CV position mismatch',
        score_fn=lambda t: t['max_cv_ml_diff'],
        show_cv=True,
        nusc=nusc,
    )
    _save_ego_dist_hist(
        infos,
        os.path.join(args.output_dir, 'occ_viz_ego_dist.png'),
    )
+
+
+# ===========================================================================
+# Entry point
+# ===========================================================================
+
def main():
    """CLI entry point — dispatches to the ``convert`` or ``visualize`` command.

    ``convert`` runs the annotation pipeline on a single pkl (--input/--output)
    or, when --input is unset, on every split found under --data-dir.
    ``visualize`` renders BEV diagnostics from an augmented pkl.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    sub = parser.add_subparsers(dest='cmd', required=True)

    p_conv = sub.add_parser('convert', help='Run annotation pipeline and save pkl.')
    p_conv.add_argument('--data-dir', default='data/infos',
                        help='Directory containing the nuScenes info pkls. '
                             'The train and val splits are processed '
                             'unless --input is given.')
    # NOTE(review): because --input has a non-None default, the per-split
    # --data-dir loop below is unreachable from the CLI — argparse can never
    # produce input=None. Confirm whether --input should default to None.
    p_conv.add_argument('--input', default='data/infos/nuscenes_infos_val.pkl',
                        help='Single input pkl (overrides --data-dir loop)')
    p_conv.add_argument('--output', default='data/infos/nuscenes_infos_val_occ.pkl',
                        help='Single output pkl (required when --input is set)')
    p_conv.add_argument('--predictions', default='data/occlusions/fmae_nuscenes_v1trainval_3class_sw_inference.npz',
                        help='Path to UniTraj inference NPZ for ML-based '
                             'extrapolation; falls back to CV when no match')
    p_conv.add_argument('--nuscenes-dataroot', default='data/nuscenes',
                        dest='nuscenes_dataroot',
                        help='nuScenes dataset root (contains v1.0-trainval/); '
                             'required when --predictions is used to map '
                             'integer instance indices to token strings')
    p_conv.add_argument('--no-extrapolate', action='store_true',
                        help='Disable forward extrapolation')
    p_conv.add_argument('--max-extrap-frames', type=int, default=9999,
                        help='Max frames to extrapolate forward per instance. '
                             'Default 9999 is effectively unlimited — the loop '
                             'always stops when the instance reappears or the '
                             'scene ends. Lower this to restrict to the ML '
                             'prediction horizon (e.g. 12 at 2 Hz).')
    p_conv.add_argument('--unitraj-dt', type=float, default=0.1,
                        dest='unitraj_dt',
                        help='Seconds per UniTraj prediction step. '
                             '0.1 for 10 Hz models (default), 0.5 for 2 Hz models.')
    p_conv.add_argument('--cv-stationary-thr', type=float, default=0.3,
                        dest='cv_stationary_thr',
                        help='Speed threshold (m/s) below which the CV fallback '
                             'uses constant position instead of constant velocity '
                             '(default: 0.3).')
    # CA / CTR motion-model fitting
    p_conv.add_argument('--min-ca-history', type=int, default=4,
                        dest='min_ca_history',
                        help='Minimum consecutive observed frames required to fit '
                             'a CA/CTR model (default: 4).')
    p_conv.add_argument('--ca-noise-thr', type=float, default=0.5,
                        dest='ca_noise_thr',
                        help='Minimum |acceleration| (m/s²) to apply CA model; '
                             'below this the track is treated as constant velocity '
                             '(default: 0.5).')
    p_conv.add_argument('--ca-max-thr', type=float, default=6.0,
                        dest='ca_max_thr',
                        help='Maximum |acceleration| (m/s²) accepted as realistic; '
                             'above this CA is rejected (default: 6.0).')
    p_conv.add_argument('--ca-consistency-thr', type=float, default=0.5,
                        dest='ca_consistency_thr',
                        help='Maximum standard deviation of per-interval acceleration '
                             '(m/s²) for the CA model to be accepted (default: 0.5).')
    p_conv.add_argument('--omega-noise-thr', type=float, default=0.07,
                        dest='omega_noise_thr',
                        help='Minimum |turn rate| (rad/s) to apply CTR model; '
                             'below this the track is treated as straight '
                             '(default: 0.07 ≈ 4 deg/s).')
    p_conv.add_argument('--omega-max-thr', type=float, default=0.6,
                        dest='omega_max_thr',
                        help='Maximum |turn rate| (rad/s) accepted as realistic; '
                             'above this CTR is rejected (default: 0.6 ≈ 34 deg/s).')
    p_conv.add_argument('--omega-consistency-thr', type=float, default=0.12,
                        dest='omega_consistency_thr',
                        help='Maximum standard deviation of per-interval turn rate '
                             '(rad/s) for the CTR model to be accepted (default: 0.12).')
    p_conv.add_argument('--max-dist', type=float, default=60.0,
                        dest='max_dist',
                        help='Two-level ego-distance filter applied after conversion: '
                             'drops instances with no original obs within this range '
                             'and removes per-frame boxes beyond it. '
                             'Set to 0 or a negative value to disable (default: 60.0).')

    p_viz = sub.add_parser('visualize', help='BEV plots of occluded annotations.')
    p_viz.add_argument('--pkl', default='data/infos/nuscenes_infos_val_occ.pkl')
    p_viz.add_argument('--scene-idx', type=int, default=None,
                       help='Scene index (0-based)')
    p_viz.add_argument('--num-scenes', type=int, default=6,
                       help='Number of example scenes to plot (default 6)')
    p_viz.add_argument('--output-dir', default='vis/gt_occ',
                       help='Directory for output images')
    p_viz.add_argument('--output', default=None,
                       help='Single output filename (overrides --output-dir)')
    p_viz.add_argument('--nuscenes-dataroot', default='data/nuscenes',
                       dest='nuscenes_dataroot',
                       help='nuScenes dataset root (e.g. /data/sets/nuscenes). '
                            'When provided, drivable-area map is rendered behind '
                            'each BEV plot using nusc.explorer.render_ego_centric_map.')
    p_viz.add_argument('--nuscenes-version', default='v1.0-trainval',
                       dest='nuscenes_version',
                       help='nuScenes version string (default: v1.0-trainval)')

    args = parser.parse_args()
    if args.cmd == 'convert':
        if args.input is not None:
            if args.output is None:
                parser.error('--output is required when --input is specified')
            cmd_convert(args)
        else:
            # Batch mode: convert each split in place under --data-dir.
            _SPLITS = ('train', 'val')
            for split in _SPLITS:
                inp = os.path.join(args.data_dir, f'nuscenes_infos_{split}.pkl')
                out = os.path.join(args.data_dir, f'nuscenes_infos_{split}_occ.pkl')
                if not os.path.exists(inp):
                    print(f'Skipping {inp} (not found)')
                    continue
                args.input = inp
                args.output = out
                print(f'\n=== {split} ===')
                cmd_convert(args)
    else:
        cmd_visualize(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/dist_test.sh b/tools/dist_test.sh
new file mode 100644
index 0000000..033365e
--- /dev/null
+++ b/tools/dist_test.sh
@@ -0,0 +1,10 @@
#!/usr/bin/env bash
# Distributed testing launcher.
#
# Usage: tools/dist_test.sh CONFIG CHECKPOINT GPUS [extra test.py args...]
# The rendezvous port can be overridden via the environment:
#   PORT=29700 tools/dist_test.sh ...

CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=${PORT:-29610}

# Quote "$0" and all forwarded values so paths/arguments containing spaces
# survive word splitting; "${@:4}" passes any extra args to test.py verbatim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python3 -m torch.distributed.launch --nproc_per_node="$GPUS" --master_port="$PORT" \
    "$(dirname "$0")"/test.py "$CONFIG" "$CHECKPOINT" --launcher pytorch "${@:4}"
diff --git a/tools/dist_train.sh b/tools/dist_train.sh
new file mode 100644
index 0000000..43e95de
--- /dev/null
+++ b/tools/dist_train.sh
@@ -0,0 +1,9 @@
#!/usr/bin/env bash
# Distributed training launcher.
#
# Usage: tools/dist_train.sh CONFIG GPUS [extra train.py args...]
# The rendezvous port can be overridden via the environment:
#   PORT=28700 tools/dist_train.sh ...

CONFIG=$1
GPUS=$2
PORT=${PORT:-28651}

# Quote "$0" and all forwarded values so paths/arguments containing spaces
# survive word splitting; "${@:3}" passes any extra args to train.py verbatim.
PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH \
python3 -m torch.distributed.launch --nproc_per_node="$GPUS" --master_port="$PORT" \
    "$(dirname "$0")"/train.py "$CONFIG" --launcher pytorch "${@:3}"
diff --git a/tools/fuse_conv_bn.py b/tools/fuse_conv_bn.py
new file mode 100644
index 0000000..9aff402
--- /dev/null
+++ b/tools/fuse_conv_bn.py
@@ -0,0 +1,68 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+
+import torch
+from mmcv.runner import save_checkpoint
+from torch import nn as nn
+
+from mmdet3d.apis import init_model
+
+
def fuse_conv_bn(conv, bn):
    """Fold *bn* into *conv* in place and return the modified conv.

    At inference time a BatchNorm layer applies a fixed per-channel affine
    transform built from its running statistics, so it can be absorbed into
    the preceding convolution:

        w' = w * gamma / sqrt(running_var + eps)
        b' = (b - running_mean) * gamma / sqrt(running_var + eps) + beta
    """
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    # A conv without a bias behaves like one with an all-zero bias.
    old_bias = conv.bias if conv.bias is not None else torch.zeros_like(
        bn.running_mean)
    conv.weight = nn.Parameter(
        conv.weight * scale.reshape([conv.out_channels, 1, 1, 1]))
    conv.bias = nn.Parameter((old_bias - bn.running_mean) * scale + bn.bias)
    return conv
+
+
def fuse_module(m):
    """Recursively fuse every Conv2d → BatchNorm pair inside *m* in place.

    Walks the direct children of *m*; whenever a (Sync)BatchNorm immediately
    follows a Conv2d, the pair is fused via ``fuse_conv_bn`` and the BN slot
    is replaced by ``nn.Identity`` so the module structure (and any indexed
    access into it) is preserved. Non-conv, non-BN children are recursed into.
    """
    pending_conv = None
    pending_name = None

    for child_name, child in m.named_children():
        if isinstance(child, nn.Conv2d):
            # Remember the conv; a directly following BN will be fused into it.
            pending_conv = child
            pending_name = child_name
        elif isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            if pending_conv is None:  # BN with no preceding conv — leave as-is
                continue
            m._modules[pending_name] = fuse_conv_bn(pending_conv, child)
            # Keep a placeholder instead of deleting, to minimise structural changes.
            m._modules[child_name] = nn.Identity()
            pending_conv = None
        else:
            fuse_module(child)
    return m
+
+
def parse_args():
    """Parse the three positional CLI arguments of the conv-bn fusion tool."""
    parser = argparse.ArgumentParser(
        description='fuse Conv and BN layers in a model')
    for arg_name, arg_help in (
            ('config', 'config file path'),
            ('checkpoint', 'checkpoint file path'),
            ('out', 'output path of the converted model')):
        parser.add_argument(arg_name, help=arg_help)
    return parser.parse_args()
+
+
def main():
    """Load the model, fuse its Conv+BN pairs, and save the fused checkpoint."""
    args = parse_args()
    # build the model from a config file and a checkpoint file
    model = init_model(args.config, args.checkpoint)
    # fuse conv and bn layers of the model
    fused_model = fuse_module(model)
    save_checkpoint(fused_model, args.out)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/kmeans/kmeans_det.py b/tools/kmeans/kmeans_det.py
new file mode 100644
index 0000000..3c8f218
--- /dev/null
+++ b/tools/kmeans/kmeans_det.py
@@ -0,0 +1,34 @@
# Cluster ground-truth box centres of the nuScenes train split into K detection
# anchors and save them (plus a scatter plot) for later use.
import os
import pickle
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

import mmcv

# Ensure output locations exist before anything is written.
os.makedirs('data/kmeans', exist_ok=True)
os.makedirs('vis/kmeans', exist_ok=True)

K = 900          # number of k-means clusters == number of anchors
DIS_THRESH = 55  # keep only boxes within this BEV ego distance (metres)

fp = 'data/infos/nuscenes_infos_train.pkl'
data = mmcv.load(fp)
data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"]))
# Gather 3-D box centres from every frame, filtered by BEV ego distance.
center = []
for idx in tqdm(range(len(data_infos))):
    boxes = data_infos[idx]['gt_boxes'][:,:3]
    if len(boxes) == 0:
        continue
    distance = np.linalg.norm(boxes[:, :2], axis=1)
    center.append(boxes[distance < DIS_THRESH])
center = np.concatenate(center, axis=0)
print("start clustering, may take a few minutes.")
cluster = KMeans(n_clusters=K).fit(center).cluster_centers_
plt.scatter(cluster[:,0], cluster[:,1])
plt.savefig(f'vis/kmeans/det_anchor_{K}', bbox_inches='tight')
# Pad each 3-D centre with 8 constant values — presumably placeholder box
# dims/yaw/velocity for the anchor layout; TODO confirm against the consumer.
others = np.array([1,1,1,1,0,0,0,0])[np.newaxis].repeat(K, axis=0)
cluster = np.concatenate([cluster, others], axis=1)
np.save(f'data/kmeans/kmeans_det_{K}.npy', cluster)
\ No newline at end of file
diff --git a/tools/kmeans/kmeans_map.py b/tools/kmeans/kmeans_map.py
new file mode 100644
index 0000000..13b34bc
--- /dev/null
+++ b/tools/kmeans/kmeans_map.py
@@ -0,0 +1,34 @@
# Cluster map-geometry centres of the nuScenes train split into K map anchors,
# expand each into a short polyline, and save them (plus a plot).
import os
import pickle
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

import mmcv

# Create output directories up front — kmeans_det.py does this too, but when
# this script runs standalone on a fresh checkout the savefig/np.save calls
# below would otherwise fail with FileNotFoundError.
os.makedirs('data/kmeans', exist_ok=True)
os.makedirs('vis/kmeans', exist_ok=True)

K = 100          # number of map-anchor clusters
num_sample = 20  # points sampled along each anchor polyline

fp = 'data/infos/nuscenes_infos_train.pkl'
data = mmcv.load(fp)
data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"]))
# Cluster the mean point of every map geometry across the whole split.
center = []
for idx in tqdm(range(len(data_infos))):
    for cls, geoms in data_infos[idx]["map_annos"].items():
        for geom in geoms:
            center.append(geom.mean(axis=0))
center = np.stack(center, axis=0)
center = KMeans(n_clusters=K).fit(center).cluster_centers_
# Expand each cluster centre into a vertical num_sample-point segment
# spanning ±4 m along y.
delta_y = np.linspace(-4, 4, num_sample)
delta_x = np.zeros([num_sample])
delta = np.stack([delta_x, delta_y], axis=-1)
vecs = center[:, np.newaxis] + delta[np.newaxis]

for i in range(K):
    x = vecs[i, :, 0]
    y = vecs[i, :, 1]
    plt.plot(x, y, linewidth=1, marker='o', linestyle='-', markersize=2)
plt.savefig(f'vis/kmeans/map_anchor_{K}', bbox_inches='tight')
np.save(f'data/kmeans/kmeans_map_{K}.npy', vecs)
\ No newline at end of file
diff --git a/tools/kmeans/kmeans_motion.py b/tools/kmeans/kmeans_motion.py
new file mode 100644
index 0000000..1c5e74c
--- /dev/null
+++ b/tools/kmeans/kmeans_motion.py
@@ -0,0 +1,101 @@
+import os
+import pickle
+from tqdm import tqdm
+
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.cluster import KMeans
+
+import mmcv
+
# Detection class names; an agent's integer label below is its index in this
# list (categories not listed here are labelled -1 and excluded).
CLASSES = [
    "car",
    "truck",
    "construction_vehicle",
    "bus",
    "trailer",
    "barrier",
    "motorcycle",
    "bicycle",
    "pedestrian",
    "traffic_cone",
]
+
def lidar2agent(trajs_offset, boxes):
    """Rotate future trajectories from the lidar frame into each agent's own
    heading-aligned frame.

    Args:
        trajs_offset: (N, T, 2) per-step xy displacements for N agents.
        boxes: (N, >=7) boxes; column 6 is the yaw angle used for rotation.

    Returns:
        (N, T, 2) cumulative xy trajectories expressed in the agent frame.
    """
    num_agents = trajs_offset.shape[0]
    start = np.zeros((num_agents, 1, 2), dtype=np.float32)
    # Prepend a zero origin, then integrate the offsets into positions.
    trajs = np.concatenate([start, trajs_offset], axis=1).cumsum(axis=1)
    theta = -boxes[:, 6]
    sin_t = np.sin(theta)
    cos_t = np.cos(theta)
    # Per-agent 2x2 rotation matrices laid out as (2, 2, N) for the einsum.
    rot_mat_T = np.stack([
        np.stack([cos_t, sin_t]),
        np.stack([-sin_t, cos_t]),
    ])
    rotated = np.einsum('aij,jka->aik', trajs, rot_mat_T)
    # Drop the prepended origin point before returning.
    return rotated[:, 1:]
+
K = 6  # clusters (motion intentions) per class
DIS_THRESH = 55  # keep agents within 55 m (xy) of the ego vehicle

fp = 'data/infos/nuscenes_infos_train.pkl'
data = mmcv.load(fp)
data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"]))

# Per-class pools of agent-frame future trajectories.
intention = {i: [] for i in range(len(CLASSES))}
for idx in tqdm(range(len(data_infos))):
    info = data_infos[idx]
    boxes = info['gt_boxes']
    names = info['gt_names']
    fut_masks = info['gt_agent_fut_masks']
    trajs = info['gt_agent_fut_trajs']
    # Map category names to class indices; unknown categories get -1.
    labels = np.array(
        [CLASSES.index(cat) if cat in CLASSES else -1 for cat in names]
    )
    if len(boxes) == 0:
        continue
    for i in range(len(CLASSES)):
        cls_mask = (labels == i)
        box_cls = boxes[cls_mask]
        fut_masks_cls = fut_masks[cls_mask]
        trajs_cls = trajs[cls_mask]

        # Keep agents with a fully observed 12-step future that are near ego.
        distance = np.linalg.norm(box_cls[:, :2], axis=1)
        mask = np.logical_and(
            fut_masks_cls.sum(axis=1) == 12,
            distance < DIS_THRESH,
        )
        trajs_agent = lidar2agent(trajs_cls[mask], box_cls[mask])
        if trajs_agent.shape[0] == 0:
            continue
        intention[i].append(trajs_agent)

# Fix: create output directories so savefig/np.save do not crash on a
# fresh checkout (both paths are git-ignored and absent by default).
os.makedirs('vis/kmeans', exist_ok=True)
os.makedirs('data/kmeans', exist_ok=True)

clusters = []
for i in range(len(CLASSES)):
    # Fix: a class with zero qualifying agents would make np.concatenate
    # raise; skip it just like classes with fewer than K samples.
    if not intention[i]:
        continue
    intention_cls = np.concatenate(intention[i], axis=0).reshape(-1, 24)
    if intention_cls.shape[0] < K:
        continue
    cluster = KMeans(n_clusters=K).fit(intention_cls).cluster_centers_
    cluster = cluster.reshape(-1, 12, 2)
    clusters.append(cluster)
    for j in range(K):
        plt.scatter(cluster[j, :, 0], cluster[j, :, 1])
    plt.savefig(f'vis/kmeans/motion_intention_{CLASSES[i]}_{K}', bbox_inches='tight')
    plt.close()

clusters = np.stack(clusters, axis=0)
np.save(f'data/kmeans/kmeans_motion_{K}.npy', clusters)
\ No newline at end of file
diff --git a/tools/kmeans/kmeans_plan.py b/tools/kmeans/kmeans_plan.py
new file mode 100644
index 0000000..33a74f7
--- /dev/null
+++ b/tools/kmeans/kmeans_plan.py
@@ -0,0 +1,39 @@
import os
import pickle
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

import mmcv

K = 6  # planning anchor trajectories per driving command

fp = 'data/infos/nuscenes_infos_train.pkl'
data = mmcv.load(fp)
data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"]))

# One pool of ego future trajectories per command in gt_ego_fut_cmd's
# one-hot layout (presumably right / left / straight — confirm vs dataset).
navi_trajs = [[], [], []]
for idx in tqdm(range(len(data_infos))):
    info = data_infos[idx]
    # Integrate per-step offsets into absolute ego positions.
    plan_traj = info['gt_ego_fut_trajs'].cumsum(axis=-2)
    plan_mask = info['gt_ego_fut_masks']
    cmd = info['gt_ego_fut_cmd'].astype(np.int32).argmax(axis=-1)
    # Keep only samples whose 6-step future is fully observed.
    if plan_mask.sum() != 6:
        continue
    navi_trajs[cmd].append(plan_traj)

# Fix: create output directories so savefig/np.save do not crash on a
# fresh checkout (both paths are git-ignored and absent by default).
os.makedirs('vis/kmeans', exist_ok=True)
os.makedirs('data/kmeans', exist_ok=True)

clusters = []
for trajs in navi_trajs:
    trajs = np.concatenate(trajs, axis=0).reshape(-1, 12)
    cluster = KMeans(n_clusters=K).fit(trajs).cluster_centers_
    cluster = cluster.reshape(-1, 6, 2)
    clusters.append(cluster)
    for j in range(K):
        plt.scatter(cluster[j, :, 0], cluster[j, :, 1])
plt.savefig(f'vis/kmeans/plan_{K}', bbox_inches='tight')
plt.close()

clusters = np.stack(clusters, axis=0)
np.save(f'data/kmeans/kmeans_plan_{K}.npy', clusters)
\ No newline at end of file
diff --git a/tools/test.py b/tools/test.py
new file mode 100644
index 0000000..6a27db2
--- /dev/null
+++ b/tools/test.py
@@ -0,0 +1,334 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import mmcv
+import os
+from os import path as osp
+
+import torch
+import warnings
+from mmcv import Config, DictAction
+from mmcv.cnn import fuse_conv_bn
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import (
+ get_dist_info,
+ init_dist,
+ load_checkpoint,
+ wrap_fp16_model,
+)
+
+from mmdet.apis import single_gpu_test, multi_gpu_test, set_random_seed
+from mmdet.datasets import replace_ImageToTensor, build_dataset
+from mmdet.datasets import build_dataloader as build_dataloader_origin
+from mmdet.models import build_detector
+
+from projects.mmdet3d_plugin.datasets.builder import build_dataloader
+from projects.mmdet3d_plugin.apis.test import custom_multi_gpu_test
+
+
def parse_args():
    """Parse command-line arguments for testing/evaluating a model.

    Returns the parsed argparse.Namespace. Also mirrors --local_rank into
    the LOCAL_RANK environment variable for distributed launchers, and
    migrates the deprecated --options into --eval-options.
    """
    parser = argparse.ArgumentParser(
        description="MMDet test (and eval) a model"
    )
    parser.add_argument("config", help="test config file path")
    parser.add_argument("checkpoint", help="checkpoint file")
    parser.add_argument("--out", help="output result file in pickle format")
    parser.add_argument(
        "--fuse-conv-bn",
        action="store_true",
        help="Whether to fuse conv and bn, this will slightly increase"
        "the inference speed",
    )
    parser.add_argument(
        "--format-only",
        action="store_true",
        help="Format the output results without perform evaluation. It is"
        "useful when you want to format the result to a specific format and "
        "submit it to the test server",
    )
    parser.add_argument(
        "--eval",
        type=str,
        nargs="+",
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC',
    )
    parser.add_argument("--show", action="store_true", help="show results")
    parser.add_argument(
        "--show-dir", help="directory where results will be saved"
    )
    parser.add_argument(
        "--gpu-collect",
        action="store_true",
        help="whether to use gpu to collect results.",
    )
    parser.add_argument(
        "--tmpdir",
        help="tmp directory used for collecting results from multiple "
        "workers, available when gpu-collect is not specified",
    )
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument(
        "--deterministic",
        action="store_true",
        help="whether to set deterministic options for CUDNN backend.",
    )
    parser.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file. If the value to "
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        "Note that the quotation marks are necessary and that no white space "
        "is allowed.",
    )
    parser.add_argument(
        "--options",
        nargs="+",
        action=DictAction,
        help="custom options for evaluation, the key-value pair in xxx=yyy "
        "format will be kwargs for dataset.evaluate() function (deprecate), "
        "change to --eval-options instead.",
    )
    parser.add_argument(
        "--eval-options",
        nargs="+",
        action=DictAction,
        help="custom options for evaluation, the key-value pair in xxx=yyy "
        "format will be kwargs for dataset.evaluate() function",
    )
    parser.add_argument(
        "--launcher",
        choices=["none", "pytorch", "slurm", "mpi"],
        default="none",
        help="job launcher",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--result_file", type=str, default=None)
    parser.add_argument("--show_only", action="store_true")
    args = parser.parse_args()
    # Distributed launchers read LOCAL_RANK from the environment; ensure it
    # is set even when only --local_rank was passed.
    if "LOCAL_RANK" not in os.environ:
        os.environ["LOCAL_RANK"] = str(args.local_rank)

    # --options is the deprecated spelling of --eval-options; accept only one.
    if args.options and args.eval_options:
        raise ValueError(
            "--options and --eval-options cannot be both specified, "
            "--options is deprecated in favor of --eval-options"
        )
    if args.options:
        warnings.warn("--options is deprecated in favor of --eval-options")
        args.eval_options = args.options
    return args
+
+
def main():
    """Entry point: build the test dataset and model from a config, run
    (single- or multi-GPU) inference or load pre-computed results, then
    save/format/show/evaluate them.

    On rank 0, evaluation metrics are optionally pushed to wandb when the
    config's log_config registers a WandbLoggerHook.
    """
    args = parse_args()

    assert (
        args.out or args.eval or args.format_only or args.show or args.show_dir
    ), (
        "Please specify at least one operation (save/eval/format/show the "
        'results / save the results) with the argument "--out", "--eval"'
        ', "--format-only", "--show" or "--show-dir"'
    )

    if args.eval and args.format_only:
        raise ValueError("--eval and --format_only cannot be both specified")

    if args.out is not None and not args.out.endswith((".pkl", ".pickle")):
        raise ValueError("The output file must be a pkl file.")

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get("custom_imports", None):
        from mmcv.utils import import_modules_from_strings

        import_modules_from_strings(**cfg["custom_imports"])

    # import modules from plguin/xx, registry will be updated
    if hasattr(cfg, "plugin"):
        if cfg.plugin:
            import importlib

            if hasattr(cfg, "plugin_dir"):
                plugin_dir = cfg.plugin_dir
                _module_dir = os.path.dirname(plugin_dir)
                _module_dir = _module_dir.split("/")
                _module_path = _module_dir[0]

                for m in _module_dir[1:]:
                    _module_path = _module_path + "." + m
                print(_module_path)
                plg_lib = importlib.import_module(_module_path)
            else:
                # import dir is the dirpath for the config file
                _module_dir = os.path.dirname(args.config)
                _module_dir = _module_dir.split("/")
                _module_path = _module_dir[0]
                for m in _module_dir[1:]:
                    _module_path = _module_path + "." + m
                print(_module_path)
                plg_lib = importlib.import_module(_module_path)

    # set cudnn_benchmark
    if cfg.get("cudnn_benchmark", False):
        torch.backends.cudnn.benchmark = True

    cfg.model.pretrained = None
    # in case the test dataset is concatenated
    samples_per_gpu = 1
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        samples_per_gpu = cfg.data.test.pop("samples_per_gpu", 1)
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline
            )
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        samples_per_gpu = max(
            [ds_cfg.pop("samples_per_gpu", 1) for ds_cfg in cfg.data.test]
        )
        if samples_per_gpu > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == "none":
        distributed = False
    else:
        distributed = True
        init_dist(args.launcher, **cfg.dist_params)

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # set work dir
    if cfg.get('work_dir', None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join('./work_dirs',
                                osp.splitext(osp.basename(args.config))[0])
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    cfg.data.test.work_dir = cfg.work_dir
    print('work_dir: ',cfg.work_dir)

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    print("distributed:", distributed)
    # Distributed runs need the plugin dataloader (distributed sampler);
    # single-GPU runs use the stock mmdet dataloader.
    if distributed:
        data_loader = build_dataloader(
            dataset,
            samples_per_gpu=samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            nonshuffler_sampler=dict(type="DistributedSampler"),
        )
    else:
        data_loader = build_dataloader_origin(
            dataset,
            samples_per_gpu=samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
        )

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg"))
    # model = build_model(cfg.model, test_cfg=cfg.get("test_cfg"))
    fp16_cfg = cfg.get("fp16", None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu")
    if args.fuse_conv_bn:
        model = fuse_conv_bn(model)
    # old versions did not save class info in checkpoints, this walkaround is
    # for backward compatibility
    if "CLASSES" in checkpoint.get("meta", {}):
        model.CLASSES = checkpoint["meta"]["CLASSES"]
    else:
        model.CLASSES = dataset.CLASSES
    # palette for visualization in segmentation tasks
    if "PALETTE" in checkpoint.get("meta", {}):
        model.PALETTE = checkpoint["meta"]["PALETTE"]
    elif hasattr(dataset, "PALETTE"):
        # segmentation dataset has `PALETTE` attribute
        model.PALETTE = dataset.PALETTE

    # Either reuse a previously saved result file, or run inference now.
    if args.result_file is not None:
        # outputs = torch.load(args.result_file)
        outputs = mmcv.load(args.result_file)
    elif not distributed:
        model = MMDataParallel(model, device_ids=[0])
        outputs = single_gpu_test(model, data_loader, args.show, args.show_dir)
    else:
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
        )
        outputs = custom_multi_gpu_test(
            model, data_loader, args.tmpdir, args.gpu_collect
        )

    # Only rank 0 writes, shows, formats, or evaluates the gathered results.
    rank, _ = get_dist_info()
    if rank == 0:
        if args.out:
            print(f"\nwriting results to {args.out}")
            mmcv.dump(outputs, args.out)
        kwargs = {} if args.eval_options is None else args.eval_options
        if args.show_only:
            eval_kwargs = cfg.get("evaluation", {}).copy()
            # hard-code way to remove EvalHook args
            for key in [
                "interval",
                "tmpdir",
                "start",
                "gpu_collect",
                "save_best",
                "rule",
            ]:
                eval_kwargs.pop(key, None)
            eval_kwargs.update(kwargs)
            dataset.show(outputs, show=True, **eval_kwargs)
        elif args.format_only:
            dataset.format_results(outputs, **kwargs)
        elif args.eval:
            eval_kwargs = cfg.get("evaluation", {}).copy()
            # hard-code way to remove EvalHook args
            for key in [
                "interval",
                "tmpdir",
                "start",
                "gpu_collect",
                "save_best",
                "rule",
            ]:
                eval_kwargs.pop(key, None)
            eval_kwargs.update(dict(metric=args.eval, **kwargs))
            print(eval_kwargs)
            results_dict = dataset.evaluate(outputs, **eval_kwargs)
            print(results_dict)

            # Log to wandb if config has WandbLoggerHook
            for hook in cfg.log_config.hooks:
                if hook.type == "WandbLoggerHook":
                    import wandb
                    if wandb.run is None:
                        wandb.init(**hook.init_kwargs)
                    wandb.log({'val/' + k: v for k, v in results_dict.items()})
                    break
+
+
if __name__ == "__main__":
    # NOTE(review): "fork" is unavailable on Windows and non-default on
    # macOS; this script assumes a Linux environment.
    torch.multiprocessing.set_start_method(
        "fork"
    )  # use fork workers_per_gpu can be > 1
    main()
diff --git a/tools/train.py b/tools/train.py
new file mode 100644
index 0000000..3ad94a6
--- /dev/null
+++ b/tools/train.py
@@ -0,0 +1,321 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from __future__ import division
+import sys
+import os
+
# Log the interpreter and script path up front to ease debugging of
# cluster-submission environments.
print(sys.executable, os.path.abspath(__file__))
# import init_paths # for conda pkgs submitting method
+import argparse
+import copy
+import mmcv
+import time
+import torch
+import warnings
+from mmcv import Config, DictAction
+from mmcv.runner import get_dist_info, init_dist
+from os import path as osp
+
+from mmdet import __version__ as mmdet_version
+from mmdet.apis import train_detector
+from mmdet.datasets import build_dataset
+from mmdet.models import build_detector
+from mmdet.utils import collect_env, get_root_logger
+from mmdet.apis import set_random_seed
+from torch import distributed as dist
+from datetime import timedelta
+
import cv2

# Cap OpenCV's internal thread pool; presumably to avoid CPU oversubscription
# alongside dataloader workers — confirm the chosen value against the host.
cv2.setNumThreads(8)
+
+
def parse_args():
    """Parse command-line arguments for training.

    Returns the parsed argparse.Namespace. Also mirrors --local_rank into
    the LOCAL_RANK environment variable for distributed launchers, and
    migrates the deprecated --options into --cfg-options.
    """
    parser = argparse.ArgumentParser(description="Train a detector")
    parser.add_argument("config", help="train config file path")
    parser.add_argument("--work-dir", help="the dir to save logs and models")
    parser.add_argument(
        "--resume-from", help="the checkpoint file to resume from"
    )
    parser.add_argument(
        "--no-validate",
        action="store_true",
        help="whether not to evaluate the checkpoint during training",
    )
    # --gpus and --gpu-ids are alternative ways to select devices; only one
    # may be given.
    group_gpus = parser.add_mutually_exclusive_group()
    group_gpus.add_argument(
        "--gpus",
        type=int,
        help="number of gpus to use "
        "(only applicable to non-distributed training)",
    )
    group_gpus.add_argument(
        "--gpu-ids",
        type=int,
        nargs="+",
        help="ids of gpus to use "
        "(only applicable to non-distributed training)",
    )
    parser.add_argument("--seed", type=int, default=0, help="random seed")
    parser.add_argument(
        "--deterministic",
        action="store_true",
        help="whether to set deterministic options for CUDNN backend.",
    )
    parser.add_argument(
        "--options",
        nargs="+",
        action=DictAction,
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    parser.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file. If the value to "
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        "Note that the quotation marks are necessary and that no white space "
        "is allowed.",
    )
    parser.add_argument(
        "--dist-url",
        type=str,
        default="auto",
        help="dist url for init process, such as tcp://localhost:8000",
    )
    parser.add_argument("--gpus-per-machine", type=int, default=8)
    parser.add_argument(
        "--launcher",
        choices=["none", "pytorch", "slurm", "mpi", "mpi_nccl"],
        default="none",
        help="job launcher",
    )
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument(
        "--autoscale-lr",
        action="store_true",
        help="automatically scale lr with the number of gpus",
    )
    args = parser.parse_args()
    # Distributed launchers read LOCAL_RANK from the environment; ensure it
    # is set even when only --local_rank was passed.
    if "LOCAL_RANK" not in os.environ:
        os.environ["LOCAL_RANK"] = str(args.local_rank)

    # --options is the deprecated spelling of --cfg-options; accept only one.
    if args.options and args.cfg_options:
        raise ValueError(
            "--options and --cfg-options cannot be both specified, "
            "--options is deprecated in favor of --cfg-options"
        )
    if args.options:
        warnings.warn("--options is deprecated in favor of --cfg-options")
        args.cfg_options = args.options

    return args
+
+
def main():
    """Entry point: load the config, set up the (possibly distributed)
    environment, build datasets and model, then launch training.

    Uses the plugin's custom_train_model when the config declares a plugin,
    otherwise falls back to mmdet's train_detector.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get("custom_imports", None):
        from mmcv.utils import import_modules_from_strings

        import_modules_from_strings(**cfg["custom_imports"])

    # import modules from plguin/xx, registry will be updated
    if hasattr(cfg, "plugin"):
        if cfg.plugin:
            import importlib

            if hasattr(cfg, "plugin_dir"):
                plugin_dir = cfg.plugin_dir
                _module_dir = os.path.dirname(plugin_dir)
                _module_dir = _module_dir.split("/")
                _module_path = _module_dir[0]

                for m in _module_dir[1:]:
                    _module_path = _module_path + "." + m
                print(_module_path)
                plg_lib = importlib.import_module(_module_path)
            else:
                # import dir is the dirpath for the config file
                _module_dir = os.path.dirname(args.config)
                _module_dir = _module_dir.split("/")
                _module_path = _module_dir[0]
                for m in _module_dir[1:]:
                    _module_path = _module_path + "." + m
                print(_module_path)
                plg_lib = importlib.import_module(_module_path)
            from projects.mmdet3d_plugin.apis.train import custom_train_model

    # set cudnn_benchmark
    if cfg.get("cudnn_benchmark", False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get("work_dir", None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join(
            "./work_dirs", osp.splitext(osp.basename(args.config))[0]
        )
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        cfg.optimizer["lr"] = cfg.optimizer["lr"] * len(cfg.gpu_ids) / 8

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == "none":
        distributed = False
    elif args.launcher == "mpi_nccl":
        # MPI-launched multi-node training: derive rank/world size from MPI
        # and initialize a NCCL process group manually.
        distributed = True

        import mpi4py.MPI as MPI

        comm = MPI.COMM_WORLD
        mpi_local_rank = comm.Get_rank()
        mpi_world_size = comm.Get_size()
        print(
            "MPI local_rank=%d, world_size=%d"
            % (mpi_local_rank, mpi_world_size)
        )

        # num_gpus = torch.cuda.device_count()
        device_ids_on_machines = list(range(args.gpus_per_machine))
        str_ids = list(map(str, device_ids_on_machines))
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str_ids)
        torch.cuda.set_device(mpi_local_rank % args.gpus_per_machine)

        dist.init_process_group(
            backend="nccl",
            init_method=args.dist_url,
            world_size=mpi_world_size,
            rank=mpi_local_rank,
            timeout=timedelta(seconds=7200),
        )

        cfg.gpu_ids = range(mpi_world_size)
        print("cfg.gpu_ids:", cfg.gpu_ids)
    else:
        distributed = True
        init_dist(
            args.launcher, timeout=timedelta(seconds=7200), **cfg.dist_params
        )
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    log_file = osp.join(cfg.work_dir, f"{timestamp}.log")
    # specify logger name, if we still use 'mmdet', the output info will be
    # filtered and won't be saved in the log_file
    # TODO: ugly workaround to judge whether we are training det or seg model
    logger = get_root_logger(
        log_file=log_file, log_level=cfg.log_level
    )

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = "\n".join([(f"{k}: {v}") for k, v in env_info_dict.items()])
    dash_line = "-" * 60 + "\n"
    logger.info(
        "Environment info:\n" + dash_line + env_info + "\n" + dash_line
    )
    meta["env_info"] = env_info
    meta["config"] = cfg.pretty_text

    # log some basic info
    logger.info(f"Distributed training: {distributed}")
    logger.info(f"Config:\n{cfg.pretty_text}")

    # set random seeds
    if args.seed is not None:
        logger.info(
            f"Set random seed to {args.seed}, "
            f"deterministic: {args.deterministic}"
        )
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta["seed"] = args.seed
    meta["exp_name"] = osp.basename(args.config)

    model = build_detector(
        cfg.model, train_cfg=cfg.get("train_cfg"), test_cfg=cfg.get("test_cfg")
    )
    model.init_weights()
    logger.info(f"Model:\n{model}")

    cfg.data.train.work_dir = cfg.work_dir
    cfg.data.val.work_dir = cfg.work_dir
    datasets = [build_dataset(cfg.data.train)]

    # A two-stage workflow (train + val) needs a second dataset that shares
    # the training pipeline.
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        # in case we use a dataset wrapper
        if "dataset" in cfg.data.train:
            val_dataset.pipeline = cfg.data.train.dataset.pipeline
        else:
            val_dataset.pipeline = cfg.data.train.pipeline
        # set test_mode=False here in deep copied config
        # which do not affect AP/AR calculation later
        # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow # noqa
        val_dataset.test_mode = False
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=mmdet_version,
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES,
        )
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES
    if hasattr(cfg, "plugin"):
        custom_train_model(
            model,
            datasets,
            cfg,
            distributed=distributed,
            validate=(not args.no_validate),
            timestamp=timestamp,
            meta=meta,
        )
    else:
        train_detector(
            model,
            datasets,
            cfg,
            distributed=distributed,
            validate=(not args.no_validate),
            timestamp=timestamp,
            meta=meta,
        )
+
+
if __name__ == "__main__":
    # NOTE(review): "fork" is unavailable on Windows and non-default on
    # macOS; this script assumes a Linux environment.
    torch.multiprocessing.set_start_method(
        "fork"
    )  # use fork workers_per_gpu can be > 1
    main()
diff --git a/tools/upload_wandb_results.py b/tools/upload_wandb_results.py
new file mode 100644
index 0000000..e33ff31
--- /dev/null
+++ b/tools/upload_wandb_results.py
@@ -0,0 +1,26 @@
+# Uploads pre-run test results to wandb as individual runs.
+# Each entry in results_dict is a run name -> metrics dict.
+# Run: python tools/upload_wandb_results.py
+
+import wandb
+from math import nan
+
+# wandb project to log to (name is set per-run below)
+init_kwargs = dict(
+ entity='trailab',
+ project='ForeSight',
+ name=None,
+)
+
+# Add entries as: results_dict[''] = {}
+results_dict = {}
+results_dict['sparsedrive_r50_stage2_4gpu_nomap_novalmask'] = {'img_bbox_NuScenes/car_AP_dist_0.5': 0.3508, 'img_bbox_NuScenes/car_AP_dist_1.0': 0.6027, 'img_bbox_NuScenes/car_AP_dist_2.0': 0.7565, 'img_bbox_NuScenes/car_AP_dist_4.0': 0.8333, 'img_bbox_NuScenes/car_trans_err': 0.3839, 'img_bbox_NuScenes/car_scale_err': 0.1462, 'img_bbox_NuScenes/car_orient_err': 0.0745, 'img_bbox_NuScenes/car_vel_err': 0.2122, 'img_bbox_NuScenes/car_attr_err': 0.1931, 'img_bbox_NuScenes/mATE': 0.559, 'img_bbox_NuScenes/mASE': 0.273, 'img_bbox_NuScenes/mAOE': 0.5665, 'img_bbox_NuScenes/mAVE': 0.2748, 'img_bbox_NuScenes/mAAE': 0.1843, 'img_bbox_NuScenes/truck_AP_dist_0.5': 0.0822, 'img_bbox_NuScenes/truck_AP_dist_1.0': 0.2492, 'img_bbox_NuScenes/truck_AP_dist_2.0': 0.4667, 'img_bbox_NuScenes/truck_AP_dist_4.0': 0.5558, 'img_bbox_NuScenes/truck_trans_err': 0.6161, 'img_bbox_NuScenes/truck_scale_err': 0.1997, 'img_bbox_NuScenes/truck_orient_err': 0.14, 'img_bbox_NuScenes/truck_vel_err': 0.1974, 'img_bbox_NuScenes/truck_attr_err': 0.1849, 'img_bbox_NuScenes/construction_vehicle_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/construction_vehicle_AP_dist_1.0': 0.0171, 'img_bbox_NuScenes/construction_vehicle_AP_dist_2.0': 0.0804, 'img_bbox_NuScenes/construction_vehicle_AP_dist_4.0': 0.1867, 'img_bbox_NuScenes/construction_vehicle_trans_err': 0.9051, 'img_bbox_NuScenes/construction_vehicle_scale_err': 0.5006, 'img_bbox_NuScenes/construction_vehicle_orient_err': 1.3696, 'img_bbox_NuScenes/construction_vehicle_vel_err': 0.1487, 'img_bbox_NuScenes/construction_vehicle_attr_err': 0.3778, 'img_bbox_NuScenes/bus_AP_dist_0.5': 0.057, 'img_bbox_NuScenes/bus_AP_dist_1.0': 0.2654, 'img_bbox_NuScenes/bus_AP_dist_2.0': 0.4994, 'img_bbox_NuScenes/bus_AP_dist_4.0': 0.6913, 'img_bbox_NuScenes/bus_trans_err': 0.6795, 'img_bbox_NuScenes/bus_scale_err': 0.2195, 'img_bbox_NuScenes/bus_orient_err': 0.1293, 'img_bbox_NuScenes/bus_vel_err': 0.4734, 'img_bbox_NuScenes/bus_attr_err': 0.2621, 
'img_bbox_NuScenes/trailer_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/trailer_AP_dist_1.0': 0.0295, 'img_bbox_NuScenes/trailer_AP_dist_2.0': 0.1291, 'img_bbox_NuScenes/trailer_AP_dist_4.0': 0.2746, 'img_bbox_NuScenes/trailer_trans_err': 0.9455, 'img_bbox_NuScenes/trailer_scale_err': 0.2619, 'img_bbox_NuScenes/trailer_orient_err': 0.6304, 'img_bbox_NuScenes/trailer_vel_err': 0.3321, 'img_bbox_NuScenes/trailer_attr_err': 0.037, 'img_bbox_NuScenes/barrier_AP_dist_0.5': 0.362, 'img_bbox_NuScenes/barrier_AP_dist_1.0': 0.5598, 'img_bbox_NuScenes/barrier_AP_dist_2.0': 0.669, 'img_bbox_NuScenes/barrier_AP_dist_4.0': 0.7042, 'img_bbox_NuScenes/barrier_trans_err': 0.3256, 'img_bbox_NuScenes/barrier_scale_err': 0.2842, 'img_bbox_NuScenes/barrier_orient_err': 0.1083, 'img_bbox_NuScenes/barrier_vel_err': nan, 'img_bbox_NuScenes/barrier_attr_err': nan, 'img_bbox_NuScenes/motorcycle_AP_dist_0.5': 0.1546, 'img_bbox_NuScenes/motorcycle_AP_dist_1.0': 0.4048, 'img_bbox_NuScenes/motorcycle_AP_dist_2.0': 0.5362, 'img_bbox_NuScenes/motorcycle_AP_dist_4.0': 0.6075, 'img_bbox_NuScenes/motorcycle_trans_err': 0.5019, 'img_bbox_NuScenes/motorcycle_scale_err': 0.2554, 'img_bbox_NuScenes/motorcycle_orient_err': 0.8037, 'img_bbox_NuScenes/motorcycle_vel_err': 0.3619, 'img_bbox_NuScenes/motorcycle_attr_err': 0.2457, 'img_bbox_NuScenes/bicycle_AP_dist_0.5': 0.2086, 'img_bbox_NuScenes/bicycle_AP_dist_1.0': 0.4115, 'img_bbox_NuScenes/bicycle_AP_dist_2.0': 0.4935, 'img_bbox_NuScenes/bicycle_AP_dist_4.0': 0.5589, 'img_bbox_NuScenes/bicycle_trans_err': 0.4162, 'img_bbox_NuScenes/bicycle_scale_err': 0.2588, 'img_bbox_NuScenes/bicycle_orient_err': 1.2732, 'img_bbox_NuScenes/bicycle_vel_err': 0.1427, 'img_bbox_NuScenes/bicycle_attr_err': 0.0065, 'img_bbox_NuScenes/pedestrian_AP_dist_0.5': 0.1827, 'img_bbox_NuScenes/pedestrian_AP_dist_1.0': 0.4434, 'img_bbox_NuScenes/pedestrian_AP_dist_2.0': 0.66, 'img_bbox_NuScenes/pedestrian_AP_dist_4.0': 0.7615, 'img_bbox_NuScenes/pedestrian_trans_err': 0.5612, 
'img_bbox_NuScenes/pedestrian_scale_err': 0.2874, 'img_bbox_NuScenes/pedestrian_orient_err': 0.5691, 'img_bbox_NuScenes/pedestrian_vel_err': 0.3303, 'img_bbox_NuScenes/pedestrian_attr_err': 0.1671, 'img_bbox_NuScenes/traffic_cone_AP_dist_0.5': 0.527, 'img_bbox_NuScenes/traffic_cone_AP_dist_1.0': 0.6597, 'img_bbox_NuScenes/traffic_cone_AP_dist_2.0': 0.7376, 'img_bbox_NuScenes/traffic_cone_AP_dist_4.0': 0.7733, 'img_bbox_NuScenes/traffic_cone_trans_err': 0.2547, 'img_bbox_NuScenes/traffic_cone_scale_err': 0.316, 'img_bbox_NuScenes/traffic_cone_orient_err': nan, 'img_bbox_NuScenes/traffic_cone_vel_err': nan, 'img_bbox_NuScenes/traffic_cone_attr_err': nan, 'img_bbox_NuScenes/NDS': 0.5210452731598937, 'img_bbox_NuScenes/mAP': 0.4135818961243956, 'img_bbox_NuScenes/amota': 0.37841917596535796, 'img_bbox_NuScenes/amotp': 1.254887412737341, 'img_bbox_NuScenes/recall': 0.48778598773767107, 'img_bbox_NuScenes/motar': 0.6749810391493056, 'img_bbox_NuScenes/gt': 14556.714285714286, 'img_bbox_NuScenes/mota': 0.34960712411451783, 'img_bbox_NuScenes/motp': 0.6241441130285599, 'img_bbox_NuScenes/mt': 2432.0, 'img_bbox_NuScenes/ml': 2342.0, 'img_bbox_NuScenes/faf': 45.385501647071784, 'img_bbox_NuScenes/tp': 60681.0, 'img_bbox_NuScenes/fp': 12701.0, 'img_bbox_NuScenes/fn': 40297.0, 'img_bbox_NuScenes/ids': 919.0, 'img_bbox_NuScenes/frag': 574.0, 'img_bbox_NuScenes/tid': 1.98160291062767, 'img_bbox_NuScenes/lgd': 2.5866146019633036, 'car_EPA': 0.4893063663779985, 'pedestrian_EPA': 0.4168936170212766, 'car_min_ade_err': 0.6188359283156746, 'car_min_fde_err': 0.9741884232379546, 'car_miss_rate_err': 0.1353851265872359, 'pedestrian_min_ade_err': 0.7137598055143394, 'pedestrian_min_fde_err': 1.0442730012917574, 'pedestrian_miss_rate_err': 0.14390649673648942, 'obj_col': 0.006701612793323066, 'obj_box_col': 0.0008248138858309378, 'L2': 0.5903816537724601}
+results_dict['sparsedrive_r50_stage2_4gpu_nomap_notrainmask'] = {'img_bbox_NuScenes/car_AP_dist_0.5': 0.3442, 'img_bbox_NuScenes/car_AP_dist_1.0': 0.5914, 'img_bbox_NuScenes/car_AP_dist_2.0': 0.7521, 'img_bbox_NuScenes/car_AP_dist_4.0': 0.8271, 'img_bbox_NuScenes/car_trans_err': 0.3811, 'img_bbox_NuScenes/car_scale_err': 0.1476, 'img_bbox_NuScenes/car_orient_err': 0.072, 'img_bbox_NuScenes/car_vel_err': 0.2106, 'img_bbox_NuScenes/car_attr_err': 0.1925, 'img_bbox_NuScenes/mATE': 0.5684, 'img_bbox_NuScenes/mASE': 0.2723, 'img_bbox_NuScenes/mAOE': 0.5279, 'img_bbox_NuScenes/mAVE': 0.2559, 'img_bbox_NuScenes/mAAE': 0.1894, 'img_bbox_NuScenes/truck_AP_dist_0.5': 0.0966, 'img_bbox_NuScenes/truck_AP_dist_1.0': 0.2541, 'img_bbox_NuScenes/truck_AP_dist_2.0': 0.4585, 'img_bbox_NuScenes/truck_AP_dist_4.0': 0.5452, 'img_bbox_NuScenes/truck_trans_err': 0.5826, 'img_bbox_NuScenes/truck_scale_err': 0.2002, 'img_bbox_NuScenes/truck_orient_err': 0.1221, 'img_bbox_NuScenes/truck_vel_err': 0.1876, 'img_bbox_NuScenes/truck_attr_err': 0.1921, 'img_bbox_NuScenes/construction_vehicle_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/construction_vehicle_AP_dist_1.0': 0.0159, 'img_bbox_NuScenes/construction_vehicle_AP_dist_2.0': 0.1082, 'img_bbox_NuScenes/construction_vehicle_AP_dist_4.0': 0.232, 'img_bbox_NuScenes/construction_vehicle_trans_err': 0.9415, 'img_bbox_NuScenes/construction_vehicle_scale_err': 0.4728, 'img_bbox_NuScenes/construction_vehicle_orient_err': 1.2952, 'img_bbox_NuScenes/construction_vehicle_vel_err': 0.1314, 'img_bbox_NuScenes/construction_vehicle_attr_err': 0.4037, 'img_bbox_NuScenes/bus_AP_dist_0.5': 0.0683, 'img_bbox_NuScenes/bus_AP_dist_1.0': 0.2488, 'img_bbox_NuScenes/bus_AP_dist_2.0': 0.5065, 'img_bbox_NuScenes/bus_AP_dist_4.0': 0.6973, 'img_bbox_NuScenes/bus_trans_err': 0.6729, 'img_bbox_NuScenes/bus_scale_err': 0.2401, 'img_bbox_NuScenes/bus_orient_err': 0.1349, 'img_bbox_NuScenes/bus_vel_err': 0.4418, 'img_bbox_NuScenes/bus_attr_err': 0.256, 
'img_bbox_NuScenes/trailer_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/trailer_AP_dist_1.0': 0.0152, 'img_bbox_NuScenes/trailer_AP_dist_2.0': 0.1115, 'img_bbox_NuScenes/trailer_AP_dist_4.0': 0.2447, 'img_bbox_NuScenes/trailer_trans_err': 1.026, 'img_bbox_NuScenes/trailer_scale_err': 0.2672, 'img_bbox_NuScenes/trailer_orient_err': 0.7522, 'img_bbox_NuScenes/trailer_vel_err': 0.2828, 'img_bbox_NuScenes/trailer_attr_err': 0.0403, 'img_bbox_NuScenes/barrier_AP_dist_0.5': 0.3544, 'img_bbox_NuScenes/barrier_AP_dist_1.0': 0.5654, 'img_bbox_NuScenes/barrier_AP_dist_2.0': 0.6728, 'img_bbox_NuScenes/barrier_AP_dist_4.0': 0.7101, 'img_bbox_NuScenes/barrier_trans_err': 0.3321, 'img_bbox_NuScenes/barrier_scale_err': 0.2847, 'img_bbox_NuScenes/barrier_orient_err': 0.115, 'img_bbox_NuScenes/barrier_vel_err': nan, 'img_bbox_NuScenes/barrier_attr_err': nan, 'img_bbox_NuScenes/motorcycle_AP_dist_0.5': 0.1725, 'img_bbox_NuScenes/motorcycle_AP_dist_1.0': 0.3832, 'img_bbox_NuScenes/motorcycle_AP_dist_2.0': 0.5372, 'img_bbox_NuScenes/motorcycle_AP_dist_4.0': 0.6004, 'img_bbox_NuScenes/motorcycle_trans_err': 0.5145, 'img_bbox_NuScenes/motorcycle_scale_err': 0.2535, 'img_bbox_NuScenes/motorcycle_orient_err': 0.6521, 'img_bbox_NuScenes/motorcycle_vel_err': 0.3426, 'img_bbox_NuScenes/motorcycle_attr_err': 0.2553, 'img_bbox_NuScenes/bicycle_AP_dist_0.5': 0.2015, 'img_bbox_NuScenes/bicycle_AP_dist_1.0': 0.4043, 'img_bbox_NuScenes/bicycle_AP_dist_2.0': 0.4961, 'img_bbox_NuScenes/bicycle_AP_dist_4.0': 0.5502, 'img_bbox_NuScenes/bicycle_trans_err': 0.4078, 'img_bbox_NuScenes/bicycle_scale_err': 0.254, 'img_bbox_NuScenes/bicycle_orient_err': 1.0339, 'img_bbox_NuScenes/bicycle_vel_err': 0.1277, 'img_bbox_NuScenes/bicycle_attr_err': 0.0091, 'img_bbox_NuScenes/pedestrian_AP_dist_0.5': 0.1789, 'img_bbox_NuScenes/pedestrian_AP_dist_1.0': 0.4351, 'img_bbox_NuScenes/pedestrian_AP_dist_2.0': 0.6511, 'img_bbox_NuScenes/pedestrian_AP_dist_4.0': 0.7558, 'img_bbox_NuScenes/pedestrian_trans_err': 0.5605, 
'img_bbox_NuScenes/pedestrian_scale_err': 0.2881, 'img_bbox_NuScenes/pedestrian_orient_err': 0.5734, 'img_bbox_NuScenes/pedestrian_vel_err': 0.323, 'img_bbox_NuScenes/pedestrian_attr_err': 0.1662, 'img_bbox_NuScenes/traffic_cone_AP_dist_0.5': 0.5255, 'img_bbox_NuScenes/traffic_cone_AP_dist_1.0': 0.6562, 'img_bbox_NuScenes/traffic_cone_AP_dist_2.0': 0.7554, 'img_bbox_NuScenes/traffic_cone_AP_dist_4.0': 0.7894, 'img_bbox_NuScenes/traffic_cone_trans_err': 0.2654, 'img_bbox_NuScenes/traffic_cone_scale_err': 0.3147, 'img_bbox_NuScenes/traffic_cone_orient_err': nan, 'img_bbox_NuScenes/traffic_cone_vel_err': nan, 'img_bbox_NuScenes/traffic_cone_attr_err': nan, 'img_bbox_NuScenes/NDS': 0.5250229358032298, 'img_bbox_NuScenes/mAP': 0.41282793609219165, 'img_bbox_NuScenes/amota': 0.37481117124367386, 'img_bbox_NuScenes/amotp': 1.243938826625982, 'img_bbox_NuScenes/recall': 0.5012763587171442, 'img_bbox_NuScenes/motar': 0.6331981367976215, 'img_bbox_NuScenes/gt': 14556.714285714286, 'img_bbox_NuScenes/mota': 0.3471572990666577, 'img_bbox_NuScenes/motp': 0.652721931604462, 'img_bbox_NuScenes/mt': 2425.0, 'img_bbox_NuScenes/ml': 2417.0, 'img_bbox_NuScenes/faf': 49.93009742722906, 'img_bbox_NuScenes/tp': 60614.0, 'img_bbox_NuScenes/fp': 13965.0, 'img_bbox_NuScenes/fn': 40432.0, 'img_bbox_NuScenes/ids': 851.0, 'img_bbox_NuScenes/frag': 554.0, 'img_bbox_NuScenes/tid': 1.8597660707921904, 'img_bbox_NuScenes/lgd': 2.312623196734337, 'car_EPA': 0.4765860609246604, 'pedestrian_EPA': 0.4042340425531915, 'car_min_ade_err': 0.6226734707457249, 'car_min_fde_err': 0.9728587197855352, 'car_miss_rate_err': 0.1323666881529005, 'pedestrian_min_ade_err': 0.7220416654547696, 'pedestrian_min_fde_err': 1.0578391945840033, 'pedestrian_miss_rate_err': 0.14206930953249303, 'obj_col': 0.006701612793323066, 'obj_box_col': 0.0009984588925565023, 'L2': 0.606941523651282}
+results_dict['sparsedrive_r50_stage2_4gpu_bs24_notrainmask'] = {'img_bbox_NuScenes/car_AP_dist_0.5': 0.3508, 'img_bbox_NuScenes/car_AP_dist_1.0': 0.5974, 'img_bbox_NuScenes/car_AP_dist_2.0': 0.7558, 'img_bbox_NuScenes/car_AP_dist_4.0': 0.831, 'img_bbox_NuScenes/car_trans_err': 0.3742, 'img_bbox_NuScenes/car_scale_err': 0.1472, 'img_bbox_NuScenes/car_orient_err': 0.0673, 'img_bbox_NuScenes/car_vel_err': 0.2134, 'img_bbox_NuScenes/car_attr_err': 0.1939, 'img_bbox_NuScenes/mATE': 0.5647, 'img_bbox_NuScenes/mASE': 0.273, 'img_bbox_NuScenes/mAOE': 0.5172, 'img_bbox_NuScenes/mAVE': 0.2548, 'img_bbox_NuScenes/mAAE': 0.1856, 'img_bbox_NuScenes/truck_AP_dist_0.5': 0.0968, 'img_bbox_NuScenes/truck_AP_dist_1.0': 0.2853, 'img_bbox_NuScenes/truck_AP_dist_2.0': 0.4715, 'img_bbox_NuScenes/truck_AP_dist_4.0': 0.5704, 'img_bbox_NuScenes/truck_trans_err': 0.5657, 'img_bbox_NuScenes/truck_scale_err': 0.2025, 'img_bbox_NuScenes/truck_orient_err': 0.1376, 'img_bbox_NuScenes/truck_vel_err': 0.1827, 'img_bbox_NuScenes/truck_attr_err': 0.1764, 'img_bbox_NuScenes/construction_vehicle_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/construction_vehicle_AP_dist_1.0': 0.0164, 'img_bbox_NuScenes/construction_vehicle_AP_dist_2.0': 0.1006, 'img_bbox_NuScenes/construction_vehicle_AP_dist_4.0': 0.2256, 'img_bbox_NuScenes/construction_vehicle_trans_err': 0.9397, 'img_bbox_NuScenes/construction_vehicle_scale_err': 0.4907, 'img_bbox_NuScenes/construction_vehicle_orient_err': 1.2775, 'img_bbox_NuScenes/construction_vehicle_vel_err': 0.1276, 'img_bbox_NuScenes/construction_vehicle_attr_err': 0.3784, 'img_bbox_NuScenes/bus_AP_dist_0.5': 0.0578, 'img_bbox_NuScenes/bus_AP_dist_1.0': 0.2705, 'img_bbox_NuScenes/bus_AP_dist_2.0': 0.506, 'img_bbox_NuScenes/bus_AP_dist_4.0': 0.7047, 'img_bbox_NuScenes/bus_trans_err': 0.6669, 'img_bbox_NuScenes/bus_scale_err': 0.2114, 'img_bbox_NuScenes/bus_orient_err': 0.142, 'img_bbox_NuScenes/bus_vel_err': 0.4896, 'img_bbox_NuScenes/bus_attr_err': 0.264, 
'img_bbox_NuScenes/trailer_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/trailer_AP_dist_1.0': 0.0188, 'img_bbox_NuScenes/trailer_AP_dist_2.0': 0.1256, 'img_bbox_NuScenes/trailer_AP_dist_4.0': 0.2585, 'img_bbox_NuScenes/trailer_trans_err': 1.0126, 'img_bbox_NuScenes/trailer_scale_err': 0.272, 'img_bbox_NuScenes/trailer_orient_err': 0.6611, 'img_bbox_NuScenes/trailer_vel_err': 0.1953, 'img_bbox_NuScenes/trailer_attr_err': 0.0356, 'img_bbox_NuScenes/barrier_AP_dist_0.5': 0.3398, 'img_bbox_NuScenes/barrier_AP_dist_1.0': 0.551, 'img_bbox_NuScenes/barrier_AP_dist_2.0': 0.6489, 'img_bbox_NuScenes/barrier_AP_dist_4.0': 0.6994, 'img_bbox_NuScenes/barrier_trans_err': 0.3263, 'img_bbox_NuScenes/barrier_scale_err': 0.2811, 'img_bbox_NuScenes/barrier_orient_err': 0.1107, 'img_bbox_NuScenes/barrier_vel_err': nan, 'img_bbox_NuScenes/barrier_attr_err': nan, 'img_bbox_NuScenes/motorcycle_AP_dist_0.5': 0.1709, 'img_bbox_NuScenes/motorcycle_AP_dist_1.0': 0.3727, 'img_bbox_NuScenes/motorcycle_AP_dist_2.0': 0.5369, 'img_bbox_NuScenes/motorcycle_AP_dist_4.0': 0.592, 'img_bbox_NuScenes/motorcycle_trans_err': 0.5122, 'img_bbox_NuScenes/motorcycle_scale_err': 0.2601, 'img_bbox_NuScenes/motorcycle_orient_err': 0.732, 'img_bbox_NuScenes/motorcycle_vel_err': 0.3777, 'img_bbox_NuScenes/motorcycle_attr_err': 0.2651, 'img_bbox_NuScenes/bicycle_AP_dist_0.5': 0.2191, 'img_bbox_NuScenes/bicycle_AP_dist_1.0': 0.4086, 'img_bbox_NuScenes/bicycle_AP_dist_2.0': 0.4922, 'img_bbox_NuScenes/bicycle_AP_dist_4.0': 0.5613, 'img_bbox_NuScenes/bicycle_trans_err': 0.4064, 'img_bbox_NuScenes/bicycle_scale_err': 0.2524, 'img_bbox_NuScenes/bicycle_orient_err': 0.9753, 'img_bbox_NuScenes/bicycle_vel_err': 0.1296, 'img_bbox_NuScenes/bicycle_attr_err': 0.0068, 'img_bbox_NuScenes/pedestrian_AP_dist_0.5': 0.1718, 'img_bbox_NuScenes/pedestrian_AP_dist_1.0': 0.4335, 'img_bbox_NuScenes/pedestrian_AP_dist_2.0': 0.6497, 'img_bbox_NuScenes/pedestrian_AP_dist_4.0': 0.7572, 'img_bbox_NuScenes/pedestrian_trans_err': 0.5675, 
'img_bbox_NuScenes/pedestrian_scale_err': 0.2907, 'img_bbox_NuScenes/pedestrian_orient_err': 0.5514, 'img_bbox_NuScenes/pedestrian_vel_err': 0.3226, 'img_bbox_NuScenes/pedestrian_attr_err': 0.1644, 'img_bbox_NuScenes/traffic_cone_AP_dist_0.5': 0.4908, 'img_bbox_NuScenes/traffic_cone_AP_dist_1.0': 0.6467, 'img_bbox_NuScenes/traffic_cone_AP_dist_2.0': 0.7377, 'img_bbox_NuScenes/traffic_cone_AP_dist_4.0': 0.7785, 'img_bbox_NuScenes/traffic_cone_trans_err': 0.2753, 'img_bbox_NuScenes/traffic_cone_scale_err': 0.3223, 'img_bbox_NuScenes/traffic_cone_orient_err': nan, 'img_bbox_NuScenes/traffic_cone_vel_err': nan, 'img_bbox_NuScenes/traffic_cone_attr_err': nan, 'img_bbox_NuScenes/NDS': 0.5267472552418652, 'img_bbox_NuScenes/mAP': 0.41255745286416134, 'img_bbox_NuScenes/amota': 0.3751674076861607, 'img_bbox_NuScenes/amotp': 1.2309562814926793, 'img_bbox_NuScenes/recall': 0.529213083568733, 'img_bbox_NuScenes/motar': 0.6298597506597094, 'img_bbox_NuScenes/gt': 14556.714285714286, 'img_bbox_NuScenes/mota': 0.3433148506844225, 'img_bbox_NuScenes/motp': 0.652450085163147, 'img_bbox_NuScenes/mt': 2494.0, 'img_bbox_NuScenes/ml': 2449.0, 'img_bbox_NuScenes/faf': 76.95633001721761, 'img_bbox_NuScenes/tp': 61285.0, 'img_bbox_NuScenes/fp': 16021.0, 'img_bbox_NuScenes/fn': 39960.0, 'img_bbox_NuScenes/ids': 652.0, 'img_bbox_NuScenes/frag': 577.0, 'img_bbox_NuScenes/tid': 1.6894107547316926, 'img_bbox_NuScenes/lgd': 2.249098392132382, 'ped_crossing': 0.5106424816863883, 'divider': 0.566774188397169, 'boundary': 0.5805887420239092, 'mAP_normal': 0.5526684707024888, 'car_EPA': 0.48456132584456185, 'pedestrian_EPA': 0.40117021276595743, 'car_min_ade_err': 0.6271592403240047, 'car_min_fde_err': 0.9811241473554299, 'car_miss_rate_err': 0.13298819295968337, 'pedestrian_min_ade_err': 0.7224249630153603, 'pedestrian_min_fde_err': 1.0487785643782717, 'pedestrian_miss_rate_err': 0.14350357990815366, 'obj_col': 0.006701612793323066, 'obj_box_col': 0.0013348961696869489, 'L2': 0.6096154629356332}
+results_dict['sparsedrive_r50_stage2_4gpu_bs24_novalmask'] = {'img_bbox_NuScenes/car_AP_dist_0.5': 0.3538, 'img_bbox_NuScenes/car_AP_dist_1.0': 0.6001, 'img_bbox_NuScenes/car_AP_dist_2.0': 0.758, 'img_bbox_NuScenes/car_AP_dist_4.0': 0.8323, 'img_bbox_NuScenes/car_trans_err': 0.3792, 'img_bbox_NuScenes/car_scale_err': 0.1473, 'img_bbox_NuScenes/car_orient_err': 0.069, 'img_bbox_NuScenes/car_vel_err': 0.2209, 'img_bbox_NuScenes/car_attr_err': 0.1944, 'img_bbox_NuScenes/mATE': 0.5547, 'img_bbox_NuScenes/mASE': 0.2744, 'img_bbox_NuScenes/mAOE': 0.5532, 'img_bbox_NuScenes/mAVE': 0.2691, 'img_bbox_NuScenes/mAAE': 0.1896, 'img_bbox_NuScenes/truck_AP_dist_0.5': 0.0902, 'img_bbox_NuScenes/truck_AP_dist_1.0': 0.2458, 'img_bbox_NuScenes/truck_AP_dist_2.0': 0.4512, 'img_bbox_NuScenes/truck_AP_dist_4.0': 0.5459, 'img_bbox_NuScenes/truck_trans_err': 0.5886, 'img_bbox_NuScenes/truck_scale_err': 0.2, 'img_bbox_NuScenes/truck_orient_err': 0.1141, 'img_bbox_NuScenes/truck_vel_err': 0.2054, 'img_bbox_NuScenes/truck_attr_err': 0.1899, 'img_bbox_NuScenes/construction_vehicle_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/construction_vehicle_AP_dist_1.0': 0.0195, 'img_bbox_NuScenes/construction_vehicle_AP_dist_2.0': 0.0852, 'img_bbox_NuScenes/construction_vehicle_AP_dist_4.0': 0.2121, 'img_bbox_NuScenes/construction_vehicle_trans_err': 0.8767, 'img_bbox_NuScenes/construction_vehicle_scale_err': 0.4977, 'img_bbox_NuScenes/construction_vehicle_orient_err': 1.5706, 'img_bbox_NuScenes/construction_vehicle_vel_err': 0.1386, 'img_bbox_NuScenes/construction_vehicle_attr_err': 0.4074, 'img_bbox_NuScenes/bus_AP_dist_0.5': 0.036, 'img_bbox_NuScenes/bus_AP_dist_1.0': 0.257, 'img_bbox_NuScenes/bus_AP_dist_2.0': 0.4972, 'img_bbox_NuScenes/bus_AP_dist_4.0': 0.6937, 'img_bbox_NuScenes/bus_trans_err': 0.6888, 'img_bbox_NuScenes/bus_scale_err': 0.2271, 'img_bbox_NuScenes/bus_orient_err': 0.134, 'img_bbox_NuScenes/bus_vel_err': 0.4766, 'img_bbox_NuScenes/bus_attr_err': 0.2613, 
'img_bbox_NuScenes/trailer_AP_dist_0.5': 0.0, 'img_bbox_NuScenes/trailer_AP_dist_1.0': 0.0222, 'img_bbox_NuScenes/trailer_AP_dist_2.0': 0.1196, 'img_bbox_NuScenes/trailer_AP_dist_4.0': 0.2607, 'img_bbox_NuScenes/trailer_trans_err': 0.9515, 'img_bbox_NuScenes/trailer_scale_err': 0.2659, 'img_bbox_NuScenes/trailer_orient_err': 0.7065, 'img_bbox_NuScenes/trailer_vel_err': 0.264, 'img_bbox_NuScenes/trailer_attr_err': 0.0367, 'img_bbox_NuScenes/barrier_AP_dist_0.5': 0.3474, 'img_bbox_NuScenes/barrier_AP_dist_1.0': 0.561, 'img_bbox_NuScenes/barrier_AP_dist_2.0': 0.6808, 'img_bbox_NuScenes/barrier_AP_dist_4.0': 0.7168, 'img_bbox_NuScenes/barrier_trans_err': 0.3446, 'img_bbox_NuScenes/barrier_scale_err': 0.2799, 'img_bbox_NuScenes/barrier_orient_err': 0.1079, 'img_bbox_NuScenes/barrier_vel_err': nan, 'img_bbox_NuScenes/barrier_attr_err': nan, 'img_bbox_NuScenes/motorcycle_AP_dist_0.5': 0.1776, 'img_bbox_NuScenes/motorcycle_AP_dist_1.0': 0.3563, 'img_bbox_NuScenes/motorcycle_AP_dist_2.0': 0.5189, 'img_bbox_NuScenes/motorcycle_AP_dist_4.0': 0.5853, 'img_bbox_NuScenes/motorcycle_trans_err': 0.4979, 'img_bbox_NuScenes/motorcycle_scale_err': 0.2571, 'img_bbox_NuScenes/motorcycle_orient_err': 0.7239, 'img_bbox_NuScenes/motorcycle_vel_err': 0.3656, 'img_bbox_NuScenes/motorcycle_attr_err': 0.2564, 'img_bbox_NuScenes/bicycle_AP_dist_0.5': 0.2349, 'img_bbox_NuScenes/bicycle_AP_dist_1.0': 0.4202, 'img_bbox_NuScenes/bicycle_AP_dist_2.0': 0.4961, 'img_bbox_NuScenes/bicycle_AP_dist_4.0': 0.56, 'img_bbox_NuScenes/bicycle_trans_err': 0.396, 'img_bbox_NuScenes/bicycle_scale_err': 0.2673, 'img_bbox_NuScenes/bicycle_orient_err': 0.9788, 'img_bbox_NuScenes/bicycle_vel_err': 0.1505, 'img_bbox_NuScenes/bicycle_attr_err': 0.0061, 'img_bbox_NuScenes/pedestrian_AP_dist_0.5': 0.1824, 'img_bbox_NuScenes/pedestrian_AP_dist_1.0': 0.4457, 'img_bbox_NuScenes/pedestrian_AP_dist_2.0': 0.6639, 'img_bbox_NuScenes/pedestrian_AP_dist_4.0': 0.7674, 'img_bbox_NuScenes/pedestrian_trans_err': 0.5609, 
'img_bbox_NuScenes/pedestrian_scale_err': 0.2893, 'img_bbox_NuScenes/pedestrian_orient_err': 0.5742, 'img_bbox_NuScenes/pedestrian_vel_err': 0.3311, 'img_bbox_NuScenes/pedestrian_attr_err': 0.1647, 'img_bbox_NuScenes/traffic_cone_AP_dist_0.5': 0.5225, 'img_bbox_NuScenes/traffic_cone_AP_dist_1.0': 0.6764, 'img_bbox_NuScenes/traffic_cone_AP_dist_2.0': 0.7529, 'img_bbox_NuScenes/traffic_cone_AP_dist_4.0': 0.789, 'img_bbox_NuScenes/traffic_cone_trans_err': 0.2632, 'img_bbox_NuScenes/traffic_cone_scale_err': 0.3128, 'img_bbox_NuScenes/traffic_cone_orient_err': nan, 'img_bbox_NuScenes/traffic_cone_vel_err': nan, 'img_bbox_NuScenes/traffic_cone_attr_err': nan, 'img_bbox_NuScenes/NDS': 0.5225890076722248, 'img_bbox_NuScenes/mAP': 0.4133932201012164, 'img_bbox_NuScenes/amota': 0.3754069553264074, 'img_bbox_NuScenes/amotp': 1.2561417111667958, 'img_bbox_NuScenes/recall': 0.5258422088207545, 'img_bbox_NuScenes/motar': 0.6338878086993166, 'img_bbox_NuScenes/gt': 14556.714285714286, 'img_bbox_NuScenes/mota': 0.34343638941858706, 'img_bbox_NuScenes/motp': 0.6429838895199319, 'img_bbox_NuScenes/mt': 2495.0, 'img_bbox_NuScenes/ml': 2226.0, 'img_bbox_NuScenes/faf': 75.21906140713715, 'img_bbox_NuScenes/tp': 61706.0, 'img_bbox_NuScenes/fp': 15583.0, 'img_bbox_NuScenes/fn': 39092.0, 'img_bbox_NuScenes/ids': 1099.0, 'img_bbox_NuScenes/frag': 635.0, 'img_bbox_NuScenes/tid': 1.6866828505090832, 'img_bbox_NuScenes/lgd': 2.2873464775229038, 'ped_crossing': 0.48785861624040217, 'divider': 0.5796305207181903, 'boundary': 0.5912029467581585, 'mAP_normal': 0.552897361238917, 'car_EPA': 0.49216307445425117, 'pedestrian_EPA': 0.41306382978723405, 'car_min_ade_err': 0.6361001780131743, 'car_min_fde_err': 1.0011809885715843, 'car_miss_rate_err': 0.13309976946893648, 'pedestrian_min_ade_err': 0.729691258942758, 'pedestrian_min_fde_err': 1.0665211591121826, 'pedestrian_miss_rate_err': 0.14577341344157752, 'obj_col': 0.006701612793323066, 'obj_box_col': 0.0013348961502843953, 'L2': 
0.6358533894850148}
+
# Replay each stored evaluation as its own W&B run, with every metric
# namespaced under "val/" so runs are comparable in the dashboard.
for run_name, metrics in results_dict.items():
    init_kwargs['name'] = run_name
    wandb.init(**init_kwargs)
    prefixed = {'val/' + metric: score for metric, score in metrics.items()}
    wandb.log(prefixed)
    wandb.finish()
diff --git a/tools/visualization/bev_render.py b/tools/visualization/bev_render.py
new file mode 100644
index 0000000..35db053
--- /dev/null
+++ b/tools/visualization/bev_render.py
@@ -0,0 +1,408 @@
+import os
+import numpy as np
+import cv2
+
+import matplotlib
+import matplotlib.pyplot as plt
+
+from projects.mmdet3d_plugin.datasets.utils import box3d_to_corners
+
# Human-readable labels for the driving-command one-hot vector in
# data['gt_ego_fut_cmd'] (selected by argmax in _render_command).
CMD_LIST = ['Turn Right', 'Turn Left', 'Go Straight']
# One matplotlib color per map-vector class, indexed by label id in draw_map_*.
COLOR_VECTORS = ['cornflowerblue', 'royalblue', 'slategrey']
# Minimum confidence for rendering predicted boxes / map vectors.
SCORE_THRESH = 0.3
MAP_SCORE_THRESH = 0.3
+color_mapping = np.asarray([
+ [0, 0, 0],
+ [255, 179, 0],
+ [128, 62, 117],
+ [255, 104, 0],
+ [166, 189, 215],
+ [193, 0, 32],
+ [206, 162, 98],
+ [129, 112, 102],
+ [0, 125, 52],
+ [246, 118, 142],
+ [0, 83, 138],
+ [255, 122, 92],
+ [83, 55, 122],
+ [255, 142, 0],
+ [179, 40, 81],
+ [244, 200, 0],
+ [127, 24, 13],
+ [147, 170, 0],
+ [89, 51, 21],
+ [241, 58, 19],
+ [35, 44, 22],
+ [112, 224, 255],
+ [70, 184, 160],
+ [153, 0, 255],
+ [71, 255, 0],
+ [255, 0, 163],
+ [255, 204, 0],
+ [0, 255, 235],
+ [255, 0, 235],
+ [255, 0, 122],
+ [255, 245, 0],
+ [10, 190, 212],
+ [214, 255, 0],
+ [0, 204, 255],
+ [20, 0, 255],
+ [255, 255, 0],
+ [0, 153, 255],
+ [0, 255, 204],
+ [41, 255, 0],
+ [173, 0, 255],
+ [0, 245, 255],
+ [71, 0, 255],
+ [0, 255, 184],
+ [0, 92, 255],
+ [184, 255, 0],
+ [255, 214, 0],
+ [25, 194, 194],
+ [92, 0, 255],
+ [220, 220, 220],
+ [255, 9, 92],
+ [112, 9, 255],
+ [8, 255, 214],
+ [255, 184, 6],
+ [10, 255, 71],
+ [255, 41, 10],
+ [7, 255, 255],
+ [224, 255, 8],
+ [102, 8, 255],
+ [255, 61, 6],
+ [255, 194, 7],
+ [0, 255, 20],
+ [255, 8, 41],
+ [255, 5, 153],
+ [6, 51, 255],
+ [235, 12, 255],
+ [160, 150, 20],
+ [0, 163, 255],
+ [140, 140, 140],
+ [250, 10, 15],
+ [20, 255, 0],
+]) / 255
+
+
class BEVRender:
    """Render bird's-eye-view (BEV) images of ground truth and predictions.

    Draws detection boxes, track histories, motion forecasts, map vectors
    and planning trajectories onto a square matplotlib canvas and saves one
    .jpg per frame into ``<out_dir>/bev_gt`` and ``<out_dir>/bev_pred``.

    ``plot_choices`` is a dict of booleans keyed by 'det', 'track',
    'motion', 'map', 'planning' and 'draw_pred' that gates each layer.
    """

    def __init__(
        self,
        plot_choices,
        out_dir,
        xlim = 40,
        ylim = 40,
    ):
        # xlim/ylim are the half-extents of the square canvas, in the same
        # units as the box coordinates (presumably meters -- TODO confirm).
        self.plot_choices = plot_choices
        self.xlim = xlim
        self.ylim = ylim
        # GT and prediction frames go to sibling output folders.
        self.gt_dir = os.path.join(out_dir, "bev_gt")
        self.pred_dir = os.path.join(out_dir, "bev_pred")
        os.makedirs(self.gt_dir, exist_ok=True)
        os.makedirs(self.pred_dir, exist_ok=True)

    def reset_canvas(self):
        """Close the previous figure and start a fresh single-axes canvas."""
        plt.close()
        self.fig, self.axes = plt.subplots(1, 1, figsize=(20, 20))
        self.axes.set_xlim(- self.xlim, self.xlim)
        self.axes.set_ylim(- self.ylim, self.ylim)
        self.axes.axis('off')

    def render(
        self,
        data,
        result,
        index,
    ):
        """Draw the GT frame and the prediction frame for sample ``index``.

        Returns:
            (save_path_gt, save_path_pred): paths of the two saved .jpg files.
        """
        # Ground-truth canvas.
        self.reset_canvas()
        self.draw_detection_gt(data)
        self.draw_motion_gt(data)
        self.draw_map_gt(data)
        self.draw_planning_gt(data)
        self._render_sdc_car()
        self._render_command(data)
        self._render_legend()
        save_path_gt = os.path.join(self.gt_dir, str(index).zfill(4) + '.jpg')
        self.save_fig(save_path_gt)

        # Prediction canvas.
        self.reset_canvas()
        self.draw_detection_pred(result)
        self.draw_track_pred(result)
        self.draw_motion_pred(result)
        self.draw_map_pred(result)
        self.draw_planning_pred(data, result)
        self._render_sdc_car()
        self._render_command(data)
        self._render_legend()
        save_path_pred = os.path.join(self.pred_dir, str(index).zfill(4) + '.jpg')
        self.save_fig(save_path_pred)

        return save_path_gt, save_path_pred

    def save_fig(self, filename):
        """Save the current figure with all margins and padding removed."""
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0,
                            hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.savefig(filename)

    def draw_detection_gt(self, data):
        """Draw each GT box as a BEV footprint ring plus a heading tick."""
        if not self.plot_choices['det']:
            return

        for i in range(data['gt_labels_3d'].shape[0]):
            label = data['gt_labels_3d'][i]
            if label == -1:  # padding / ignored annotation
                continue
            # GT has no instance ids, so color simply by annotation index.
            color = color_mapping[i % len(color_mapping)]

            # draw corners: [0, 3, 7, 4, 0] selects a closed four-corner ring
            # of the cuboid (presumably one face; the exact ordering depends
            # on box3d_to_corners -- TODO confirm).
            corners = box3d_to_corners(data['gt_bboxes_3d'])[i, [0, 3, 7, 4, 0]]
            x = corners[:, 0]
            y = corners[:, 1]
            self.axes.plot(x, y, color=color, linewidth=3, linestyle='-')

            # draw line to indicate forward direction
            forward_center = np.mean(corners[2:4], axis=0)
            center = np.mean(corners[0:4], axis=0)
            x = [forward_center[0], center[0]]
            y = [forward_center[1], center[1]]
            self.axes.plot(x, y, color=color, linewidth=3, linestyle='-')

    def draw_detection_pred(self, result):
        """Draw predicted boxes scoring above SCORE_THRESH, colored per instance."""
        if not (self.plot_choices['draw_pred'] and self.plot_choices['det'] and "boxes_3d" in result):
            return

        bboxes = result['boxes_3d']
        for i in range(result['labels_3d'].shape[0]):
            score = result['scores_3d'][i]
            if score < SCORE_THRESH:
                continue
            # Stable color per tracked instance across frames.
            color = color_mapping[result['instance_ids'][i] % len(color_mapping)]

            # draw corners (same closed four-corner ring as the GT drawing)
            corners = box3d_to_corners(bboxes)[i, [0, 3, 7, 4, 0]]
            x = corners[:, 0]
            y = corners[:, 1]
            self.axes.plot(x, y, color=color, linewidth=3, linestyle='-')

            # draw line to indicate forward direction
            forward_center = np.mean(corners[2:4], axis=0)
            center = np.mean(corners[0:4], axis=0)
            x = [forward_center[0], center[0]]
            y = [forward_center[1], center[1]]
            self.axes.plot(x, y, color=color, linewidth=3, linestyle='-')

    def draw_track_pred(self, result):
        """Draw each kept instance's historical boxes from the anchor queue
        and a polyline connecting its current and past centers."""
        if not (self.plot_choices['draw_pred'] and self.plot_choices['track'] and "anchor_queue" in result):
            return

        temp_bboxes = result["anchor_queue"]
        period = result["period"]  # per-instance count of valid history steps
        bboxes = result['boxes_3d']
        for i in range(result['labels_3d'].shape[0]):
            score = result['scores_3d'][i]
            if score < SCORE_THRESH:
                continue
            color = color_mapping[result['instance_ids'][i] % len(color_mapping)]
            center = bboxes[i, :3]
            centers = [center]
            for j in range(period[i]):
                # draw corners of the j-th most recent queued box
                corners = box3d_to_corners(temp_bboxes[:, -1-j])[i, [0, 3, 7, 4, 0]]
                x = corners[:, 0]
                y = corners[:, 1]
                self.axes.plot(x, y, color=color, linewidth=2, linestyle='-')

                # draw line to indicate forward direction
                # (note: this rebinds `center` to the ring's mean, which then
                # feeds the history polyline below)
                forward_center = np.mean(corners[2:4], axis=0)
                center = np.mean(corners[0:4], axis=0)
                x = [forward_center[0], center[0]]
                y = [forward_center[1], center[1]]
                self.axes.plot(x, y, color=color, linewidth=2, linestyle='-')
                centers.append(center)

            # NOTE(review): the first element is the current box center while
            # later ones are corner-ring means from the queue -- mixed sources;
            # confirm they are expressed in the same frame.
            centers = np.stack(centers)
            xs = centers[:, 0]
            ys = centers[:, 1]
            self.axes.plot(xs, ys, color=color, linewidth=2, linestyle='-')

    def draw_motion_gt(self, data):
        """Draw GT future trajectories (per-step offsets cumulative-summed
        from each agent's current center)."""
        if not self.plot_choices['motion']:
            return

        for i in range(data['gt_labels_3d'].shape[0]):
            label = data['gt_labels_3d'][i]
            if label == -1:
                continue
            color = color_mapping[i % len(color_mapping)]
            # Bigger dots for vehicle classes; assumes nuScenes-style label
            # indices where these ids are vehicles -- TODO confirm.
            vehicle_id_list = [0, 1, 2, 3, 4, 6, 7]
            if label in vehicle_id_list:
                dot_size = 150
            else:
                dot_size = 25

            center = data['gt_bboxes_3d'][i, :2]
            masks = data['gt_agent_fut_masks'][i].astype(bool)
            if masks[0] == 0:  # first future step invalid -> nothing to draw
                continue
            # Offsets -> absolute positions, prefixed with the current center.
            trajs = data['gt_agent_fut_trajs'][i][masks]
            trajs = trajs.cumsum(axis=0) + center
            trajs = np.concatenate([center.reshape(1, 2), trajs], axis=0)

            self._render_traj(trajs, traj_score=1.0,
                              colormap='winter', dot_size=dot_size)

    def draw_motion_pred(self, result, top_k=3):
        """Draw the top-k predicted trajectory modes per kept detection,
        fading each mode by its score relative to the best mode."""
        if not (self.plot_choices['draw_pred'] and self.plot_choices['motion'] and "trajs_3d" in result):
            return

        bboxes = result['boxes_3d']
        labels = result['labels_3d']
        for i in range(result['labels_3d'].shape[0]):
            score = result['scores_3d'][i]
            if score < SCORE_THRESH:
                continue
            label = labels[i]
            vehicle_id_list = [0, 1, 2, 3, 4, 6, 7]
            if label in vehicle_id_list:
                dot_size = 150
            else:
                dot_size = 25

            traj_score = result['trajs_score'][i].numpy()
            traj = result['trajs_3d'][i].numpy()
            num_modes = len(traj_score)
            # Prepend the current box center to every mode.
            center = bboxes[i, :2][None, None].repeat(num_modes, 1, 1).numpy()
            traj = np.concatenate([center, traj], axis=1)

            sorted_ind = np.argsort(traj_score)[::-1]  # best mode first
            sorted_traj = traj[sorted_ind, :, :2]
            sorted_score = traj_score[sorted_ind]
            norm_score = np.exp(sorted_score[0])

            # Draw worst-of-top-k first so the best mode ends up on top.
            for j in range(top_k - 1, -1, -1):
                viz_traj = sorted_traj[j]
                # NOTE: rebinds the name `traj_score` to a scalar; safe only
                # because it is re-read from `result` on the next outer loop.
                traj_score = np.exp(sorted_score[j])/norm_score
                self._render_traj(viz_traj, traj_score=traj_score,
                                  colormap='winter', dot_size=dot_size)

    def draw_map_gt(self, data):
        """Draw GT map polylines, one color per vector class."""
        if not self.plot_choices['map']:
            return
        vectors = data['map_infos']
        for label, vector_list in vectors.items():
            color = COLOR_VECTORS[label]
            for vector in vector_list:
                pts = vector[:, :2]
                x = np.array([pt[0] for pt in pts])
                y = np.array([pt[1] for pt in pts])
                self.axes.plot(x, y, color=color, linewidth=3, marker='o', linestyle='-', markersize=7)

    def draw_map_pred(self, result):
        """Draw predicted map polylines scoring above MAP_SCORE_THRESH."""
        if not (self.plot_choices['draw_pred'] and self.plot_choices['map'] and "vectors" in result):
            return

        for i in range(result['scores'].shape[0]):
            score = result['scores'][i]
            if score < MAP_SCORE_THRESH:
                continue
            color = COLOR_VECTORS[result['labels'][i]]
            pts = result['vectors'][i]
            x = pts[:, 0]
            y = pts[:, 1]
            # NOTE(review): uses plt.plot (current axes) rather than
            # self.axes.plot as every other method does; equivalent only
            # while this figure is matplotlib's current one.
            plt.plot(x, y, color=color, linewidth=3, marker='o', linestyle='-', markersize=7)

    def draw_planning_gt(self, data):
        """Draw the GT ego future trajectory, cumulative-summed from the origin."""
        if not self.plot_choices['planning']:
            return

        # draw planning gt
        masks = data['gt_ego_fut_masks'].astype(bool)
        if masks[0] != 0:
            plan_traj = data['gt_ego_fut_trajs'][masks]
            cmd = data['gt_ego_fut_cmd']  # NOTE(review): read but never used
            plan_traj[abs(plan_traj) < 0.01] = 0.0  # squash numeric noise
            plan_traj = plan_traj.cumsum(axis=0)
            plan_traj = np.concatenate((np.zeros((1, plan_traj.shape[1])), plan_traj), axis=0)
            self._render_traj(plan_traj, traj_score=1.0,
                              colormap='autumn', dot_size=50)

    def draw_planning_pred(self, data, result, top_k=3):
        """Draw ego history boxes (when tracking is enabled) and the top-k
        planned trajectories for the ground-truth driving command."""
        if not (self.plot_choices['draw_pred'] and self.plot_choices['planning'] and "planning" in result):
            return

        if self.plot_choices['track'] and "ego_anchor_queue" in result:
            ego_temp_bboxes = result["ego_anchor_queue"]
            ego_period = result["ego_period"]
            for j in range(ego_period[0]):
                # draw corners of the j-th most recent ego box
                corners = box3d_to_corners(ego_temp_bboxes[:, -1-j])[0, [0, 3, 7, 4, 0]]
                x = corners[:, 0]
                y = corners[:, 1]
                self.axes.plot(x, y, color='mediumseagreen', linewidth=2, linestyle='-')

                # draw line to indicate forward direction
                forward_center = np.mean(corners[2:4], axis=0)
                center = np.mean(corners[0:4], axis=0)
                x = [forward_center[0], center[0]]
                y = [forward_center[1], center[1]]
                self.axes.plot(x, y, color='mediumseagreen', linewidth=2, linestyle='-')
        # Prepend the origin to every (command, mode) trajectory.
        plan_trajs = result['planning'].cpu().numpy()
        num_cmd = len(CMD_LIST)
        num_mode = plan_trajs.shape[1]
        plan_trajs = np.concatenate((np.zeros((num_cmd, num_mode, 1, 2)), plan_trajs), axis=2)
        plan_score = result['planning_score'].cpu().numpy()

        # Keep only the modes belonging to the ground-truth command.
        cmd = data['gt_ego_fut_cmd'].argmax()
        plan_trajs = plan_trajs[cmd]
        plan_score = plan_score[cmd]

        sorted_ind = np.argsort(plan_score)[::-1]
        sorted_traj = plan_trajs[sorted_ind, :, :2]
        sorted_score = plan_score[sorted_ind]
        norm_score = np.exp(sorted_score[0])

        # Draw worst-of-top-k first so the best plan ends up on top.
        for j in range(top_k - 1, -1, -1):
            viz_traj = sorted_traj[j]
            traj_score = np.exp(sorted_score[j]) / norm_score
            self._render_traj(viz_traj, traj_score=traj_score,
                              colormap='autumn', dot_size=50)

    def _render_traj(
        self,
        future_traj,
        traj_score=1,
        colormap='winter',
        points_per_step=20,
        dot_size=25
    ):
        """Scatter a trajectory as densely interpolated dots.

        Each consecutive waypoint pair is linearly interpolated with
        ``points_per_step`` dots; colors run along ``colormap`` over the
        trajectory and are blended toward white by (1 - traj_score).
        """
        total_steps = (len(future_traj) - 1) * points_per_step + 1
        dot_colors = matplotlib.colormaps[colormap](
            np.linspace(0, 1, total_steps))[:, :3]
        dot_colors = dot_colors * traj_score + \
            (1 - traj_score) * np.ones_like(dot_colors)
        total_xy = np.zeros((total_steps, 2))
        for i in range(total_steps - 1):
            # Linear interpolation within segment i // points_per_step.
            unit_vec = future_traj[i // points_per_step +
                                   1] - future_traj[i // points_per_step]
            total_xy[i] = (i / points_per_step - i // points_per_step) * \
                unit_vec + future_traj[i // points_per_step]
        total_xy[-1] = future_traj[-1]
        self.axes.scatter(
            total_xy[:, 0], total_xy[:, 1], c=dot_colors, s=dot_size)

    def _render_sdc_car(self):
        """Overlay the ego-car icon at the origin (extent (-1,1,-2,2) in BEV units)."""
        # NOTE(review): cv2.imread returns None when the file is missing and
        # cvtColor would then raise -- assumes the script runs from the repo root.
        sdc_car_png = cv2.imread('resources/sdc_car.png')
        sdc_car_png = cv2.cvtColor(sdc_car_png, cv2.COLOR_BGR2RGB)
        im = self.axes.imshow(sdc_car_png, extent=(-1, 1, -2, 2))
        im.set_zorder(2)

    def _render_legend(self):
        """Overlay the static legend image in the bottom-right corner."""
        legend = cv2.imread('resources/legend.png')
        legend = cv2.cvtColor(legend, cv2.COLOR_BGR2RGB)
        self.axes.imshow(legend, extent=(15, 40, -40, -30))

    def _render_command(self, data):
        """Print the GT driving-command text in the bottom-left corner."""
        cmd = data['gt_ego_fut_cmd'].argmax()
        self.axes.text(-38, -38, CMD_LIST[cmd], fontsize=60)
\ No newline at end of file
diff --git a/tools/visualization/cam_render.py b/tools/visualization/cam_render.py
new file mode 100644
index 0000000..be05651
--- /dev/null
+++ b/tools/visualization/cam_render.py
@@ -0,0 +1,271 @@
+import os
+import numpy as np
+import cv2
+from PIL import Image
+
+import matplotlib
+import matplotlib.pyplot as plt
+from pyquaternion import Quaternion
+from nuscenes.utils.data_classes import Box as NuScenesBox
+from nuscenes.utils.geometry_utils import view_points, box_in_image, BoxVisibility, transform_matrix
+
+from tools.visualization.bev_render import (
+ color_mapping,
+ SCORE_THRESH,
+ MAP_SCORE_THRESH,
+ CMD_LIST
+)
+
+
# Camera order used for the 2x3 display mosaic (top row left-to-right,
# then bottom row).
CAM_NAMES_NUSC = [
    'CAM_FRONT_LEFT',
    'CAM_FRONT',
    'CAM_FRONT_RIGHT',
    'CAM_BACK_RIGHT',
    'CAM_BACK',
    'CAM_BACK_LEFT',
]
# Camera order of the underlying data arrays (img_filename / cam_intrinsic /
# lidar2cam); `.index()` on this list maps a display camera name to its data
# index. Presumably matches the dataset loader's ordering -- TODO confirm.
CAM_NAMES_NUSC_converter = [
    'CAM_FRONT',
    'CAM_FRONT_RIGHT',
    'CAM_FRONT_LEFT',
    'CAM_BACK',
    'CAM_BACK_LEFT',
    'CAM_BACK_RIGHT',
]
+
+class CamRender:
+ def __init__(
+ self,
+ plot_choices,
+ out_dir,
+ ):
+ self.plot_choices = plot_choices
+ self.pred_dir = os.path.join(out_dir, "cam_pred")
+ os.makedirs(self.pred_dir, exist_ok=True)
+
    def reset_canvas(self):
        """Close the previous figure and create a fresh 2x3 camera grid."""
        plt.close()
        plt.gca().set_axis_off()
        plt.axis('off')
        self.fig, self.axes = plt.subplots(2, 3, figsize=(160 /3 , 20))
        plt.tight_layout()
+
    def render(
        self,
        data,
        result,
        index,
    ):
        """Compose the six camera images with prediction overlays and save
        them as ``<pred_dir>/<index:04d>.jpg``.

        Returns:
            save_path: path of the saved .jpg file.
        """
        self.reset_canvas()
        self.render_image_data(data, index)
        self.draw_detection_pred(data, result)
        self.draw_motion_pred(data, result)
        self.draw_planning_pred(data, result)
        save_path = os.path.join(self.pred_dir, str(index).zfill(4) + '.jpg')
        self.save_fig(save_path)
        return save_path
+
+ def load_image(self, data_path, cam):
+ """Update the axis of the plot with the provided image."""
+ image = np.array(Image.open(data_path))
+ font = cv2.FONT_HERSHEY_SIMPLEX
+ org = (50, 60)
+ fontScale = 2
+ color = (0, 0, 0)
+ thickness = 4
+ return cv2.putText(image, cam, org, font, fontScale, color, thickness, cv2.LINE_AA)
+
    def update_image(self, image, index, cam):
        """Show ``image`` on the axes for flat mosaic slot ``index`` (axes hidden)."""
        # NOTE(review): the `cam` argument is unused here.
        ax = self.get_axis(index)
        ax.imshow(image)
        plt.axis('off')
        ax.axis('off')
        ax.grid(False)
+
+ def get_axis(self, index):
+ """Retrieve the corresponding axis based on the index."""
+ return self.axes[index//3, index % 3]
+
    def save_fig(self, filename):
        """Save the current figure with all margins and padding removed."""
        plt.subplots_adjust(top=1, bottom=0, right=1, left=0,
                            hspace=0, wspace=0)
        plt.margins(0, 0)
        plt.savefig(filename)
+
+ def render_image_data(self, data, index):
+ """Load and annotate image based on the provided path."""
+ for i, cam in enumerate(CAM_NAMES_NUSC):
+ idx = CAM_NAMES_NUSC_converter.index(cam)
+ img_path = data['img_filename'][idx]
+ image = self.load_image(img_path, cam)
+ self.update_image(image, i, cam)
+
    def draw_detection_pred(self, data, result):
        """Project predicted 3D boxes above SCORE_THRESH into each camera image."""
        if not (self.plot_choices['draw_pred'] and self.plot_choices['det'] and "boxes_3d" in result):
            return

        bboxes = result['boxes_3d'].numpy()
        for j, cam in enumerate(CAM_NAMES_NUSC):
            idx = CAM_NAMES_NUSC_converter.index(cam)
            cam_intrinsic = data['cam_intrinsic'][idx]
            lidar2cam = data['lidar2cam']
            extrinsic = lidar2cam[idx]
            # NOTE(review): treats lidar2cam as a row-vector transform
            # (pts @ R + t): translation from the last row, rotation as the
            # inverse of the top-left 3x3 -- confirm this matches how
            # lidar2cam is built in the data pipeline.
            trans = extrinsic[3, :3]
            rot = Quaternion(matrix=extrinsic[:3, :3]).inverse
            imsize = (1600, 900)  # assumes 1600x900 nuScenes images -- TODO confirm

            for i in range(result['labels_3d'].shape[0]):
                score = result['scores_3d'][i]
                if score < SCORE_THRESH:
                    continue
                # Stable color per tracked instance across frames.
                color = color_mapping[result['instance_ids'][i] % len(color_mapping)]

                center = bboxes[i, 0 : 3]
                box_dims = bboxes[i, 3 : 6]
                # Reorders dims for NuScenesBox (wlh); assumes boxes_3d stores
                # them in the complementary order -- TODO confirm.
                nusc_dims = box_dims[..., [1, 0, 2]]
                quat = Quaternion(axis=[0, 0, 1], radians=bboxes[i, 6])
                box = NuScenesBox(
                    center,
                    nusc_dims,
                    quat
                )
                box.rotate(rot)
                box.translate(trans)
                # Only draw boxes that actually fall inside this camera view.
                if box_in_image(box, cam_intrinsic, imsize):
                    box.render(
                        self.axes[j // 3, j % 3],
                        view=cam_intrinsic,
                        normalize=True,
                        colors=(color, color, color),
                        linewidth=4,
                    )

            # Clip the axes to the image and flip y (image coordinates).
            self.axes[j//3, j % 3].set_xlim(0, imsize[0])
            self.axes[j//3, j % 3].set_ylim(imsize[1], 0)
+
    def draw_motion_pred(self, data, result, points_per_step=10):
        """Project each kept detection's most likely trajectory mode into
        every camera and scatter it (skipping boxes outside the image).

        NOTE(review): ``points_per_step`` is unused in this method.
        """
        if not (self.plot_choices['draw_pred'] and self.plot_choices['motion'] and "trajs_3d" in result):
            return

        bboxes = result['boxes_3d'].numpy()
        for j, cam in enumerate(CAM_NAMES_NUSC):
            idx = CAM_NAMES_NUSC_converter.index(cam)
            cam_intrinsic = data['cam_intrinsic'][idx]
            lidar2cam = data['lidar2cam']
            extrinsic = lidar2cam[idx]
            # Same row-vector extrinsic convention as draw_detection_pred --
            # confirm against the data pipeline.
            trans = extrinsic[3, :3]
            rot = Quaternion(matrix=extrinsic[:3, :3]).inverse
            imsize = (1600, 900)

            for i in range(result['labels_3d'].shape[0]):
                score = result['scores_3d'][i]
                if score < SCORE_THRESH:
                    continue
                color = color_mapping[result['instance_ids'][i] % len(color_mapping)]

                traj_score = result['trajs_score'][i].numpy()
                traj = result['trajs_3d'][i].numpy()

                mode_idx = traj_score.argmax()  # only the most likely mode
                traj = traj[mode_idx]
                # Prepend the current box center as the trajectory origin.
                origin = bboxes[i, :2][None]
                traj = np.concatenate([origin, traj], axis=0)
                # Give every waypoint the height of the box's bottom face.
                traj_expand = np.ones((traj.shape[0], 1))
                traj_expand[:] = bboxes[i, 2] - bboxes[i, 5] / 2
                traj = np.concatenate([traj, traj_expand], axis=1)

                center = bboxes[i, 0 : 3]
                box_dims = bboxes[i, 3 : 6]
                nusc_dims = box_dims[..., [1, 0, 2]]
                quat = Quaternion(axis=[0, 0, 1], radians=bboxes[i, 6])
                box = NuScenesBox(
                    center,
                    nusc_dims,
                    quat
                )
                box.rotate(rot)
                box.translate(trans)
                if not box_in_image(box, cam_intrinsic, imsize):
                    continue
                # Transform waypoints to camera coordinates (row-vector form),
                # then hand off to the camera-space trajectory renderer
                # (presumably defined later in this class -- not visible here).
                traj_points = traj @ extrinsic[:3, :3] + trans
                self._render_traj(traj_points, cam_intrinsic, j, color=color, s=15)
+
+ def draw_planning_pred(self, data, result):
+ if not (self.plot_choices['draw_pred'] and self.plot_choices['planning'] and "planning" in result):
+ return
+ # for j, cam in enumerate(CAM_NAMES_NUSC[1]):
+ # idx = CAM_NAMES_NUSC_converter.index(cam)
+ # cam_intrinsic = data['cam_intrinsic'][idx]
+ # lidar2cam = data['lidar2cam']
+ # extrinsic = lidar2cam[idx]
+ # trans = extrinsic[3, :3]
+ # rot = Quaternion(matrix=extrinsic[:3, :3]).inverse
+ # imsize = (1600, 900)
+
+ # plan_trajs = result['planning'][0].cpu().numpy()
+ # plan_trajs = plan_trajs.reshape(3, -1, 6, 2)
+ # num_cmd = len(CMD_LIST)
+ # num_mode = plan_trajs.shape[1]
+ # plan_trajs = np.concatenate((np.zeros((num_cmd, num_mode, 1, 2)), plan_trajs), axis=2)
+ # plan_trajs = plan_trajs.cumsum(axis=-2)
+ # plan_score = result['planning_score'][0].cpu().numpy()
+ # plan_score = plan_score.reshape(3, -1)
+
+ # cmd = data['gt_ego_fut_cmd'].argmax()
+ # plan_trajs = plan_trajs[cmd]
+ # plan_score = plan_score[cmd]
+
+ # mode_idx = plan_score.argmax()
+ # plan_traj = plan_trajs[mode_idx]
+ # traj_expand = np.ones((plan_traj.shape[0], 1)) * -2
+ # # traj_expand[:] = bboxes[i, 2] - bboxes[i, 5] / 2
+ # plan_traj = np.concatenate([plan_traj, traj_expand], axis=1)
+
+ # traj_points = plan_traj @ extrinsic[:3, :3] + trans
+ # self._render_traj(traj_points, cam_intrinsic, j)
+
+ idx = 0 ## front camera
+ cam_intrinsic = data['cam_intrinsic'][idx]
+ lidar2cam = data['lidar2cam']
+ extrinsic = lidar2cam[idx]
+ trans = extrinsic[3, :3]
+ rot = Quaternion(matrix=extrinsic[:3, :3]).inverse
+ # plan_trajs = result['planning'][0].cpu().numpy()
+ # plan_trajs = plan_trajs.reshape(3, -1, 6, 2)
+ # num_cmd = len(CMD_LIST)
+ # num_mode = plan_trajs.shape[1]
+ # plan_trajs = np.concatenate((np.zeros((num_cmd, num_mode, 1, 2)), plan_trajs), axis=2)
+ # plan_trajs = plan_trajs.cumsum(axis=-2)
+ # plan_score = result['planning_score'][0].cpu().numpy()
+ # plan_score = plan_score.reshape(3, -1)
+
+ # cmd = data['gt_ego_fut_cmd'].argmax()
+ # plan_trajs = plan_trajs[cmd]
+ # plan_score = plan_score[cmd]
+
+ # mode_idx = plan_score.argmax()
+ # plan_traj = plan_trajs[mode_idx]
+ plan_traj = result["final_planning"]
+ plan_traj = np.concatenate((np.zeros((1, 2)), plan_traj), axis=0)
+ traj_expand = np.ones((plan_traj.shape[0], 1)) * -1.8
+ plan_traj = np.concatenate([plan_traj, traj_expand], axis=1)
+
+ traj_points = plan_traj @ extrinsic[:3, :3] + trans
+ self._render_traj(traj_points, cam_intrinsic, j=1)
+
+ def _render_traj(self, traj_points, cam_intrinsic, j, color=(1, 0.5, 0), s=150, points_per_step=10):
+ total_steps = (len(traj_points)-1) * points_per_step + 1
+ total_xy = np.zeros((total_steps, 3))
+ for k in range(total_steps-1):
+ unit_vec = traj_points[k//points_per_step +
+ 1] - traj_points[k//points_per_step]
+ total_xy[k] = (k/points_per_step - k//points_per_step) * \
+ unit_vec + traj_points[k//points_per_step]
+ in_range_mask = total_xy[:, 2] > 0.1
+ traj_points = view_points(
+ total_xy.T, cam_intrinsic, normalize=True)[:2, :]
+ traj_points = traj_points[:2, in_range_mask]
+ self.axes[j // 3, j % 3].scatter(traj_points[0], traj_points[1], color=color, s=s)
\ No newline at end of file
diff --git a/tools/visualization/visualize.py b/tools/visualization/visualize.py
new file mode 100644
index 0000000..591908a
--- /dev/null
+++ b/tools/visualization/visualize.py
@@ -0,0 +1,110 @@
+import os
+import glob
+import argparse
+from tqdm import tqdm
+
+import cv2
+import numpy as np
+from PIL import Image
+
+import mmcv
+from mmcv import Config
+from mmdet.datasets import build_dataset
+
+from tools.visualization.bev_render import BEVRender
+from tools.visualization.cam_render import CamRender
+
# Which ground-truth / prediction layers to draw.
plot_choices = {
    'draw_pred': True,  # True: draw gt and pred; False: only draw gt
    'det': True,
    'track': True,      # True: draw history tracked boxes
    'motion': True,
    'map': True,
    'planning': True,
}
# Frame range to visualize: [START, END) stepping by INTERVAL.
START = 0
END = 81
INTERVAL = 1
+
+
class Visualizer:
    """Render per-frame BEV + camera visualizations, combine them into
    side-by-side images, and optionally encode the frames into a video."""

    def __init__(
        self,
        args,
        plot_choices,
    ):
        self.out_dir = args.out_dir
        self.combine_dir = os.path.join(self.out_dir, 'combine')
        os.makedirs(self.combine_dir, exist_ok=True)

        cfg = Config.fromfile(args.config)
        self.dataset = build_dataset(cfg.data.val)
        self.results = mmcv.load(args.result_path)
        self.bev_render = BEVRender(plot_choices, self.out_dir)
        self.cam_render = CamRender(plot_choices, self.out_dir)

    def add_vis(self, index):
        """Render BEV gt/pred and camera pred for one frame, then combine."""
        data = self.dataset.get_data_info(index)
        result = self.results[index]['img_bbox']

        bev_gt_path, bev_pred_path = self.bev_render.render(data, result, index)
        cam_pred_path = self.cam_render.render(data, result, index)
        self.combine(bev_gt_path, bev_pred_path, cam_pred_path, index)

    def combine(self, bev_gt_path, bev_pred_path, cam_pred_path, index):
        """Concatenate camera, BEV-pred and BEV-gt images horizontally and
        save the result as ``<combine_dir>/<index:04d>.jpg``."""
        bev_gt = cv2.imread(bev_gt_path)
        bev_image = cv2.imread(bev_pred_path)
        cam_image = cv2.imread(cam_pred_path)
        merge_image = cv2.hconcat([cam_image, bev_image, bev_gt])
        save_path = os.path.join(self.combine_dir, str(index).zfill(4) + '.jpg')
        cv2.imwrite(save_path, merge_image)

    def image2video(self, fps=12, downsample=4):
        """Encode all combined frames into ``<out_dir>/video.mp4``.

        Args:
            fps: frame rate of the output video.
            downsample: integer factor by which each frame is shrunk.
        """
        imgs_path = sorted(glob.glob(os.path.join(self.combine_dir, '*.jpg')))
        if not imgs_path:
            # BUG FIX: previously `size` was referenced after the loop
            # without ever being assigned when no frames existed, raising
            # UnboundLocalError; bail out instead.
            return
        img_array = []
        size = None
        for img_path in tqdm(imgs_path):
            img = cv2.imread(img_path)
            height, width, _ = img.shape
            img = cv2.resize(img, (width // downsample, height // downsample),
                             interpolation=cv2.INTER_AREA)
            height, width, _ = img.shape
            size = (width, height)
            img_array.append(img)
        out_path = os.path.join(self.out_dir, 'video.mp4')
        out = cv2.VideoWriter(
            out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size)
        for frame in img_array:
            out.write(frame)
        out.release()
+
+
def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: optional argument list; defaults to ``sys.argv[1:]``. Added
            (backward-compatibly) so the parser can be driven from tests.

    Returns:
        argparse.Namespace with ``config``, ``result_path`` and ``out_dir``.
    """
    parser = argparse.ArgumentParser(
        description='Visualize groundtruth and results')
    parser.add_argument('config', help='config file path')
    parser.add_argument('--result-path',
                        default=None,
                        # BUG FIX: the implicitly-concatenated help text had
                        # no separator ("...visualizeIf submission...").
                        help='prediction result to visualize. '
                        'If submission file is not provided, only gt will be visualized')
    parser.add_argument(
        '--out-dir',
        default='vis',
        help='directory where visualize results will be saved')
    args = parser.parse_args(argv)

    return args
+
def main():
    """Visualize frames [START, END) and encode them into a video."""
    args = parse_args()
    visualizer = Visualizer(args, plot_choices)

    for idx in tqdm(range(START, END, INTERVAL)):
        # BUG FIX: use >= — with the previous `>` check, idx == len(results)
        # slipped through and add_vis raised an IndexError on that frame.
        if idx >= len(visualizer.results):
            break
        visualizer.add_vis(idx)

    visualizer.image2video()


if __name__ == '__main__':
    main()
\ No newline at end of file