
Commit 8a64de5

Add support for BEVFusion
1 parent c5dfdd7 commit 8a64de5

20 files changed: 2,444 additions, 5 deletions

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
from .convfuser import ConvFuser
__all__ = {
    'ConvFuser': ConvFuser
}
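For context, a dict-shaped __all__ like this typically acts as a name-to-class registry that a builder indexes with the NAME field of the model config. A minimal sketch of that pattern follows; the build_fuser helper and the _FUSERS name are illustrative assumptions, not part of this commit:

# Sketch only: how a dict-style __all__ registry like the one above is typically
# consumed; the build_fuser helper below is an assumption, not part of this commit.
from .convfuser import ConvFuser  # same relative import as in the __init__.py above

_FUSERS = {
    'ConvFuser': ConvFuser,
}


def build_fuser(model_cfg):
    # Look up the fuser class named by the config and instantiate it with that config.
    return _FUSERS[model_cfg.NAME](model_cfg=model_cfg)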
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
import torch
from torch import nn


class ConvFuser(nn.Module):
    def __init__(self, model_cfg) -> None:
        super().__init__()
        self.model_cfg = model_cfg
        in_channel = self.model_cfg.IN_CHANNEL
        out_channel = self.model_cfg.OUT_CHANNEL
        # 3x3 conv + BN + ReLU that maps the concatenated BEV features to out_channel
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(True)
        )

    def forward(self, batch_dict):
        """
        Args:
            batch_dict:
                spatial_features_img (tensor): BEV features from the image modality
                spatial_features (tensor): BEV features from the lidar modality

        Returns:
            batch_dict:
                spatial_features (tensor): BEV features after multi-modal fusion
        """
        img_bev = batch_dict['spatial_features_img']
        lidar_bev = batch_dict['spatial_features']
        # Concatenate the two BEV maps along the channel dimension and fuse them
        cat_bev = torch.cat([img_bev, lidar_bev], dim=1)
        mm_bev = self.conv(cat_bev)
        batch_dict['spatial_features'] = mm_bev
        return batch_dict
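To make the data flow concrete, here is a minimal sketch of running ConvFuser on dummy BEV maps. The channel counts and the 180x180 BEV grid are illustrative assumptions, not values fixed by this commit:

import torch
from types import SimpleNamespace

# Hypothetical sizes: 80-channel camera BEV + 256-channel lidar BEV on a 180x180 grid.
cfg = SimpleNamespace(IN_CHANNEL=80 + 256, OUT_CHANNEL=256)
fuser = ConvFuser(cfg)  # ConvFuser as defined in the file above

batch_dict = {
    'spatial_features_img': torch.randn(2, 80, 180, 180),  # camera BEV features
    'spatial_features': torch.randn(2, 256, 180, 180),     # lidar BEV features
}
batch_dict = fuser(batch_dict)
# The fused map replaces 'spatial_features' and keeps the BEV resolution.
assert batch_dict['spatial_features'].shape == (2, 256, 180, 180)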
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
from .swin import SwinTransformer
__all__ = {
    'SwinTransformer': SwinTransformer,
}
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
from .generalized_lss import GeneralizedLSSFPN
__all__ = {
    'GeneralizedLSSFPN': GeneralizedLSSFPN,
}
Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...model_utils.basic_block_2d import BasicBlock2D


class GeneralizedLSSFPN(nn.Module):
    """
    This module implements FPN, which creates pyramid features built on top of some input feature maps.
    This code is adapted from https://github.com/open-mmlab/mmdetection/blob/main/mmdet/models/necks/fpn.py with minimal modifications.
    """
    def __init__(self, model_cfg):
        super().__init__()
        self.model_cfg = model_cfg
        in_channels = self.model_cfg.IN_CHANNELS
        out_channels = self.model_cfg.OUT_CHANNELS
        num_ins = len(in_channels)
        num_outs = self.model_cfg.NUM_OUTS
        start_level = self.model_cfg.START_LEVEL
        end_level = self.model_cfg.END_LEVEL

        self.in_channels = in_channels

        if end_level == -1:
            self.backbone_end_level = num_ins - 1
        else:
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level

        self.lateral_convs = nn.ModuleList()
        self.fpn_convs = nn.ModuleList()

        for i in range(self.start_level, self.backbone_end_level):
            l_conv = BasicBlock2D(
                in_channels[i] + (in_channels[i + 1] if i == self.backbone_end_level - 1 else out_channels),
                out_channels, kernel_size=1, bias=False
            )
            fpn_conv = BasicBlock2D(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
            self.lateral_convs.append(l_conv)
            self.fpn_convs.append(fpn_conv)

    def forward(self, batch_dict):
        """
        Args:
            batch_dict:
                image_features (list[tensor]): Multi-stage features from image backbone.
        Returns:
            batch_dict:
                image_fpn (list[tensor]): FPN features.
        """
        # upsample -> cat -> conv1x1 -> conv3x3
        inputs = batch_dict['image_features']
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [inputs[i + self.start_level] for i in range(len(inputs))]

        # build top-down path
        used_backbone_levels = len(laterals) - 1
        for i in range(used_backbone_levels - 1, -1, -1):
            x = F.interpolate(
                laterals[i + 1],
                size=laterals[i].shape[2:],
                mode='bilinear', align_corners=False,
            )
            laterals[i] = torch.cat([laterals[i], x], dim=1)
            laterals[i] = self.lateral_convs[i](laterals[i])
            laterals[i] = self.fpn_convs[i](laterals[i])

        # build outputs
        outs = [laterals[i] for i in range(used_backbone_levels)]
        batch_dict['image_fpn'] = tuple(outs)
        return batch_dict
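For reference, a minimal sketch of exercising this neck on dummy multi-stage image features. The stage channels, NUM_OUTS, and spatial sizes are assumptions chosen to resemble a Swin-style backbone, not values mandated by this commit, and the sketch assumes it runs inside the pcdet package so the BasicBlock2D import resolves:

import torch
from types import SimpleNamespace

# Hypothetical config: Swin-T style stage channels, three output levels, top level merged away.
cfg = SimpleNamespace(
    IN_CHANNELS=[96, 192, 384, 768],
    OUT_CHANNELS=256,
    NUM_OUTS=3,
    START_LEVEL=0,
    END_LEVEL=-1,
)
neck = GeneralizedLSSFPN(cfg)  # GeneralizedLSSFPN as defined in the file above

# Four backbone stages for a 256x704 image (strides 4, 8, 16, 32).
batch_dict = {
    'image_features': [
        torch.randn(2, 96, 64, 176),
        torch.randn(2, 192, 32, 88),
        torch.randn(2, 384, 16, 44),
        torch.randn(2, 768, 8, 22),
    ]
}
batch_dict = neck(batch_dict)
# Yields three maps of OUT_CHANNELS channels at the three finest input resolutions:
# (2, 256, 64, 176), (2, 256, 32, 88), (2, 256, 16, 44).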

0 commit comments
