
Commit 12cd095

add tensorrt
1 parent aa3876b commit 12cd095

File tree

18 files changed: +815 -174 lines

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -110,4 +110,5 @@ play.py
 preprocess_data.py
 res/
 adj.md
+tensorrt/build/*

README.md

Lines changed: 20 additions & 5 deletions
@@ -4,20 +4,23 @@ My implementation of [BiSeNetV1](https://arxiv.org/abs/1808.00897) and [BiSeNetV
 
 
 The mIOU evaluation results of the models trained and evaluated on the cityscapes train/val set are:
-| none | ss | ssc | msf | mscf | fps | link |
+| none | ss | ssc | msf | mscf | fps(fp16/fp32) | link |
 |------|:--:|:---:|:---:|:----:|:---:|:----:|
-| bisenetv1 | 74.85 | 76.46 | 77.36 | 78.72 | - | [download](https://drive.google.com/file/d/1e1_E7OrpjTaD5Rael7Fus5lg-uGZ5TUZ/view?usp=sharing) |
-| bisenetv2 | 74.39 | 74.44 | 76.10 | 75.94 | - | [download](https://drive.google.com/file/d/1r_F-KZg-3s2pPcHRIuHZhZ0DQ0wocudk/view?usp=sharing) |
+| bisenetv1 | 75.55 | 76.90 | 77.40 | 78.91 | 60/19 | [download](https://drive.google.com/file/d/140MBBAt49N1z1wsKueoFA6HB_QuYud8i/view?usp=sharing) |
+| bisenetv2 | 74.12 | 74.18 | 75.89 | 75.87 | 50/16 | [download](https://drive.google.com/file/d/1qq38u9JT4pp1ubecGLTCHHtqwntH0FCY/view?usp=sharing) |
 
 > Where **ss** means single scale evaluation, **ssc** means single scale crop evaluation, **msf** means multi-scale evaluation with flip augment, and **mscf** means multi-scale crop evaluation with flip augment. The eval scales of multi-scale evaluation are `[0.5, 0.75, 1.0, 1.25, 1.5, 1.75]`, and the crop size of crop evaluation is `[1024, 1024]`.
 
+> The fps is tested in a different way from the paper. For more information, please see [here](./tensorrt).
+
 Note that the models have a big variance, which means that the results of repeated training runs would vary within a relatively big margin. For example, if you train bisenetv2 several times, you will observe that its **ss** evaluation result varies between 72.1 and 74.4.
 
 
 ## platform
 My platform is like this:
-* ubuntu 16.04
-* cuda 10.1.243
+* ubuntu 18.04
+* nvidia Tesla T4 gpu, driver 450.51.05
+* cuda 10.2
 * cudnn 7
 * miniconda python 3.6.9
 * pytorch 1.6.0
@@ -59,7 +62,12 @@ Then you need to change the field of `im_root` and `train/val_im_anns` in the co
 In order to train the model, you can run commands like this:
 ```
 $ export CUDA_VISIBLE_DEVICES=0,1
+
+# if you want to train with apex
 $ python -m torch.distributed.launch --nproc_per_node=2 tools/train.py --model bisenetv2 # or bisenetv1
+
+# if you want to train with the pytorch fp16 feature from torch 1.6
+$ python -m torch.distributed.launch --nproc_per_node=2 tools/train_amp.py --model bisenetv2 # or bisenetv1
 ```
 
 Note that though `bisenetv2` has fewer flops, it requires many more training iterations, so the training time of `bisenetv1` is shorter.
@@ -70,6 +78,9 @@ You can also load the trained model weights and finetune from it:
 ```
 $ export CUDA_VISIBLE_DEVICES=0,1
 $ python -m torch.distributed.launch --nproc_per_node=2 tools/train.py --finetune-from ./res/model_final.pth --model bisenetv2 # or bisenetv1
+
+# same with the pytorch fp16 feature
+$ python -m torch.distributed.launch --nproc_per_node=2 tools/train_amp.py --finetune-from ./res/model_final.pth --model bisenetv2 # or bisenetv1
 ```
 
 
@@ -79,6 +90,10 @@ You can also evaluate a trained model like this:
 $ python tools/evaluate.py --model bisenetv1 --weight-path /path/to/your/weight.pth
 ```
 
+## Infer with tensorrt
+You can go to [tensorrt](./tensorrt) for details.
+
+
 ### Be aware that this is the refactored version of the original codebase. You can go to the `old` directory for the original implementation.
 
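The new "Infer with tensorrt" section only links to the ./tensorrt directory. As a minimal sketch of the usual first step of such a pipeline, the model is typically exported to ONNX so TensorRT can parse it. The sketch below follows the `output_aux=False` path added in this commit, but the output file name, input size, and opset choice are assumptions, not taken from the repo:

```
import torch

from lib.models.bisenetv1 import BiSeNetV1  # assumes repo root on PYTHONPATH

# Hedged sketch: build the aux-free inference graph and export it to ONNX.
# PixelShuffle, fixed-scale Upsample, and the final argmax all map to static
# ONNX ops, which is presumably why this commit rewrites the model this way.
net = BiSeNetV1(19, output_aux=False)
net.eval()
dummy = torch.randn(1, 3, 1024, 2048)  # cityscapes-sized input, assumed
torch.onnx.export(net, dummy, 'model.onnx',
                  input_names=['input'], output_names=['preds'],
                  opset_version=11)
```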

datasets/cityscapes/gtFine

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-/data2/zzy/.datasets/cityscapes/gtFine/
+/data2/zzy/.datasets/cityscapes//gtFine/

datasets/cityscapes/leftImg8bit

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-/data2/zzy/.datasets/cityscapes/leftImg8bit/
+/data2/zzy/.datasets/cityscapes//leftImg8bit/

lib/base_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ def __init__(self, dataroot, annpath, trans_func=None, mode='train'):
 
     def __getitem__(self, idx):
         impth, lbpth = self.img_paths[idx], self.lb_paths[idx]
-        img, label = cv2.imread(impth), cv2.imread(lbpth, 0)
+        img, label = cv2.imread(impth)[:, :, ::-1], cv2.imread(lbpth, 0)
         if not self.lb_map is None:
             label = self.lb_map[label]
         im_lb = dict(im=img, lb=label)
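For context on the `[:, :, ::-1]` change above: `cv2.imread` returns pixels in BGR channel order, and reversing the channel axis converts to RGB. A small self-contained check of that equivalence (the image path is a placeholder):

```
import cv2
import numpy as np

# cv2.imread yields an HWC uint8 array in BGR order; reversing the last
# axis gives RGB, the order that RGB mean/std constants and most
# ImageNet-pretrained backbones expect.
img_bgr = cv2.imread('example.jpg')  # placeholder path
img_rgb = img_bgr[:, :, ::-1]
assert np.array_equal(img_rgb, cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
```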

lib/models/bisenetv1.py

Lines changed: 50 additions & 25 deletions
@@ -13,6 +13,7 @@
 
 
 class ConvBNReLU(nn.Module):
+
     def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
         super(ConvBNReLU, self).__init__()
         self.conv = nn.Conv2d(in_chan,
@@ -38,16 +39,39 @@ def init_weight(self):
         if not ly.bias is None: nn.init.constant_(ly.bias, 0)
 
 
+class UpSample(nn.Module):
+
+    def __init__(self, n_chan, factor=2):
+        super(UpSample, self).__init__()
+        out_chan = n_chan * factor * factor
+        self.proj = nn.Conv2d(n_chan, out_chan, 1, 1, 0)
+        self.up = nn.PixelShuffle(factor)
+        self.init_weight()
+
+    def forward(self, x):
+        feat = self.proj(x)
+        feat = self.up(feat)
+        return feat
+
+    def init_weight(self):
+        nn.init.xavier_normal_(self.proj.weight, gain=1.)
+
+
 class BiSeNetOutput(nn.Module):
-    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
+
+    def __init__(self, in_chan, mid_chan, n_classes, up_factor=32, *args, **kwargs):
         super(BiSeNetOutput, self).__init__()
+        self.up_factor = up_factor
+        out_chan = n_classes * up_factor * up_factor
         self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
-        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
+        self.conv_out = nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=True)
+        self.up = nn.PixelShuffle(up_factor)
         self.init_weight()
 
     def forward(self, x):
         x = self.conv(x)
         x = self.conv_out(x)
+        x = self.up(x)
         return x
 
     def init_weight(self):
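The new `UpSample` and `BiSeNetOutput` heads replace runtime-sized `F.interpolate` calls with a 1x1 projection followed by `nn.PixelShuffle`, so every shape in the graph is static. A quick shape check of the rearrangement, with channel counts chosen only for illustration:

```
import torch
import torch.nn as nn

# PixelShuffle moves channels into space: (N, C*r*r, H, W) -> (N, C, H*r, W*r).
# Projecting to C*r*r channels first turns the pair into an x-r upsampler
# with no dynamic shape computation, which ONNX/TensorRT handle easily.
r = 8
proj = nn.Conv2d(128, 19 * r * r, kernel_size=1)
up = nn.PixelShuffle(r)
x = torch.randn(1, 128, 96, 192)
print(up(proj(x)).shape)  # torch.Size([1, 19, 768, 1536])
```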
@@ -79,7 +103,7 @@ def __init__(self, in_chan, out_chan, *args, **kwargs):
 
     def forward(self, x):
         feat = self.conv(x)
-        atten = F.avg_pool2d(feat, feat.size()[2:])
+        atten = torch.mean(feat, dim=(2, 3), keepdim=True)
         atten = self.conv_atten(atten)
         atten = self.bn_atten(atten)
         atten = self.sigmoid_atten(atten)
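This hunk (and the two matching ones below) replaces `F.avg_pool2d` over the full spatial extent with `torch.mean` over dims 2 and 3. The two are the same global average pool; the `torch.mean` form just avoids deriving a kernel size from a runtime tensor shape, which traces and exports more cleanly:

```
import torch
import torch.nn.functional as F

# Global average pooling written both ways; the results agree to float
# precision, so the swap changes the exported graph but not the model.
feat = torch.randn(2, 128, 24, 48)
pooled_old = F.avg_pool2d(feat, feat.size()[2:])
pooled_new = torch.mean(feat, dim=(2, 3), keepdim=True)
print(torch.allclose(pooled_old, pooled_new, atol=1e-6))  # True
```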
@@ -102,28 +126,25 @@ def __init__(self, *args, **kwargs):
         self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
         self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
         self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
+        self.up32 = nn.Upsample(scale_factor=2.)
+        self.up16 = nn.Upsample(scale_factor=2.)
 
         self.init_weight()
 
     def forward(self, x):
-        H0, W0 = x.size()[2:]
         feat8, feat16, feat32 = self.resnet(x)
-        H8, W8 = feat8.size()[2:]
-        H16, W16 = feat16.size()[2:]
-        H32, W32 = feat32.size()[2:]
 
-        avg = F.avg_pool2d(feat32, feat32.size()[2:])
+        avg = torch.mean(feat32, dim=(2, 3), keepdim=True)
         avg = self.conv_avg(avg)
-        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
 
         feat32_arm = self.arm32(feat32)
-        feat32_sum = feat32_arm + avg_up
-        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
+        feat32_sum = feat32_arm + avg
+        feat32_up = self.up32(feat32_sum)
         feat32_up = self.conv_head32(feat32_up)
 
         feat16_arm = self.arm16(feat16)
         feat16_sum = feat16_arm + feat32_up
-        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
+        feat16_up = self.up16(feat16_sum)
         feat16_up = self.conv_head16(feat16_up)
 
         return feat16_up, feat32_up  # x8, x16
@@ -203,7 +224,7 @@ def __init__(self, in_chan, out_chan, *args, **kwargs):
     def forward(self, fsp, fcp):
         fcat = torch.cat([fsp, fcp], dim=1)
         feat = self.convblk(fcat)
-        atten = F.avg_pool2d(feat, feat.size()[2:])
+        atten = torch.mean(feat, dim=(2, 3), keepdim=True)
         atten = self.conv1(atten)
         atten = self.relu(atten)
         atten = self.conv2(atten)
@@ -231,14 +252,17 @@ def get_params(self):
 
 
 class BiSeNetV1(nn.Module):
-    def __init__(self, n_classes, *args, **kwargs):
+
+    def __init__(self, n_classes, output_aux=True, *args, **kwargs):
         super(BiSeNetV1, self).__init__()
         self.cp = ContextPath()
         self.sp = SpatialPath()
         self.ffm = FeatureFusionModule(256, 256)
-        self.conv_out = BiSeNetOutput(256, 256, n_classes)
-        self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
-        self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
+        self.conv_out = BiSeNetOutput(256, 256, n_classes, up_factor=8)
+        self.output_aux = output_aux
+        if self.output_aux:
+            self.conv_out16 = BiSeNetOutput(128, 64, n_classes, up_factor=8)
+            self.conv_out32 = BiSeNetOutput(128, 64, n_classes, up_factor=16)
         self.init_weight()
 
     def forward(self, x):
@@ -248,13 +272,12 @@ def forward(self, x):
         feat_fuse = self.ffm(feat_sp, feat_cp8)
 
         feat_out = self.conv_out(feat_fuse)
-        feat_out16 = self.conv_out16(feat_cp8)
-        feat_out32 = self.conv_out32(feat_cp16)
-
-        feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
-        feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
-        feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
-        return feat_out, feat_out16, feat_out32
+        if self.output_aux:
+            feat_out16 = self.conv_out16(feat_cp8)
+            feat_out32 = self.conv_out32(feat_cp16)
+            return feat_out, feat_out16, feat_out32
+        feat_out = feat_out.argmax(dim=1)
+        return feat_out
 
     def init_weight(self):
         for ly in self.children():
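With `output_aux=False`, `forward` now skips the auxiliary heads and returns the argmax label map directly, so a deployed engine emits class ids rather than logits. A minimal inference sketch under the module layout in this diff (import path follows the file's location; the input size is assumed):

```
import torch

from lib.models.bisenetv1 import BiSeNetV1  # assumes repo root on PYTHONPATH

# Inference-mode usage: forward() already applies argmax over the class
# dim, so the output is an integer label map of shape (N, H, W).
net = BiSeNetV1(19, output_aux=False)
net.eval()
with torch.no_grad():
    preds = net(torch.randn(1, 3, 1024, 2048))
print(preds.shape, preds.dtype)  # torch.Size([1, 1024, 2048]) torch.int64
```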
@@ -276,11 +299,13 @@ def get_params(self):
 
 
 if __name__ == "__main__":
-    net = BiSeNet(19)
+    net = BiSeNetV1(19)
     net.cuda()
     net.eval()
     in_ten = torch.randn(16, 3, 640, 480).cuda()
     out, out16, out32 = net(in_ten)
     print(out.shape)
+    print(out16.shape)
+    print(out32.shape)
 
     net.get_params()
