Skip to content

Commit 68a6ec1

Browse files
committed
main
1 parent 4a7fd5b commit 68a6ec1

File tree

117 files changed

+9947
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

117 files changed

+9947
-0
lines changed

README.md

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# TreeFilter-Torch
2+
This project provides a cuda implementation for "[Learnable Tree Filter for Structure-preserving
3+
Feature Transform](https://megvii-my.sharepoint.cn/:b:/g/personal/songlin_megvii_com/EfbrITIdvqBCu-SaW9gZOHQBFIkcIisB6-FyO9SzzrZyPQ?e=YI06YP)" on PyTorch. Multiple semantic segmentation experiments are reproduced to verify the effectiveness of tree filtering module on PASCAL VOC2012 and Cityscapes. For the reason that the experiments in the paper were conducted using internal framework, this project reimplements them on PyTorch and reports detailed comparisons below. In addition, many thanks to [TorchSeg](https://github.com/ycszen/TorchSeg).
4+
5+
![introduce image](demo/introduce.png)
6+
7+
## Prerequisites
8+
- PyTorch 1.2
9+
- `sudo pip3 install torch torchvision`
10+
- Easydict
11+
- `sudo pip3 install easydict`
12+
- Apex
13+
- `https://nvidia.github.io/apex/index.html`
14+
- Ninja
15+
- `sudo apt-get install ninja-build`
16+
- tqdm
17+
- `sudo pip3 install tqdm`
18+
- Boost (optional for Prim and Kruskal algorithm)
19+
- `sudo apt-get install libboost-dev`
20+
21+
## Installation
22+
### Building from source
23+
- `git clone https://github.com/StevenGrove/TreeFilter-Seg`
24+
- `cd TreeFilter-Seg/furnace/kernels/lib_tree_filter`
25+
- `sudo python3 setup.py build develop`
26+
27+
This project implements three well-known algorithms of minimal spanning tree, i.e., Boruvka, Kruskal and Prim. The default algorithm is set to *Boruvka* for its linear computational complexity in the plain graph. The user can change the configuration in the source file "lib_tree_filter/src/mst/mst.cu" .
28+
29+
## Pretrained Model
30+
- ResNet-50 [GoogleDrive](https://drive.google.com/open?id=1tRO4SUL0rdjXbKcyp1CQ6SefkL9QtX1b)
31+
- ResNet-101 [GoogleDrive](https://drive.google.com/open?id=11t0f0FcLOPj7KvHYdIAGANNbbWU_fJ1d)
32+
33+
## Performance and Benchmarks
34+
### Notes
35+
FCN-32d: FCN with decoder whose maximum stride is 32;
36+
Extra: Global average pooling + ResBlock;
37+
TF: Learnable tree filtering module;
38+
SS: Single-scale;
39+
MSF: Multi-scale + Flip.
40+
41+
### PASCAL VOC 2012 *val* set
42+
Methods | Backbone | mIoU (ss) | Acc (ss) | mIoU (msf) | Acc (msf) | Model
43+
:--:|:--:|:--:|:--:|:--:|:--:|:--:
44+
FCN-32d | R50_v1c | 71.82% | 93.62% | 73.96% | 94.14% | [GoogleDrive](https://drive.google.com/open?id=1Wzdhfa1mh_JFcqvLKPs7dXWgkOCqTnoH)
45+
FCN-32d+TF | R50_v1c | 76.31% | 94.57% | 77.80% | 94.96% | [GoogleDrive](https://drive.google.com/open?id=19wwP7KW8aCWjyd21zGLhrMz2g3lW9o9Z)
46+
FCN-32d | R101_v1c | 74.53% | 94.29% | 76.08% | 94.63% | [GoogleDrive](https://drive.google.com/open?id=19HQYK5JMS2bw2CbkTmG0VypnYfT3p-NN)
47+
FCN-32d+TF | R101_v1c | 77.82% | 94.92% | 79.22% | 95.22% | [GoogleDrive](https://drive.google.com/open?id=1HywWQn-sHR9iddHTiLyYHNsH3TClvQMo)
48+
FCN-32d+Extra | R101_v1c | 78.04% | 95.01% | 79.69% | 95.41% | [GoogleDrive](https://drive.google.com/open?id=1dzag3GVcY9k-6ExOb1B4zQqfmtt4PbBy)
49+
FCN-32d+Extra+TF | R101_v1c | 79.81% | 95.38% | 80.97% | 95.67% | [GoogleDrive](https://drive.google.com/open?id=1sfZyuL2pikmhWLRw9-XbrJpZJEbjawh6)
50+
FCN-32d+Extra+TF<sup>*</sup> | R101_v1c | 80.32% | 95.66% | 82.28% | 96.01% | [GoogleDrive](https://drive.google.com/open?id=19FpTs6NtfJLsLwN_03U4A2zTpPSIAnfS)
51+
52+
<sup>*</sup> further finetuned on the original train set
53+
54+
### Cityscapes *val* set
55+
Methods | Backbone | mIoU (ss) | Acc (ss) | mIoU (msf) | Acc (msf) | Model
56+
:--:|:--:|:--:|:--:|:--:|:--:|:--:
57+
FCN-32d+Extra | R101_v1c | 78.29% | 96.09% | 79.40% | 96.27% | [GoogleDrive](https://drive.google.com/open?id=1MT4-ZzuCTNgfpRHGG6fT4TkUVtFuuJO5)
58+
FCN-32d+Extra+TF | R101_v1c | 79.58% | 96.31% | 80.85% | 96.46% | [GoogleDrive](https://drive.google.com/open?id=1yXPEUrIZ1CfFk7-1YHgMlDNz81Fhp1kz)
59+
60+
## Usage
61+
As in the original TorchSeg, distributed training is recommended for either single machine or multiple machines.
62+
For detailed usage, please refer to the [Training](https://github.com/ycszen/TorchSeg#Training) and [Inference](https://github.com/ycszen/TorchSeg#Inference) sections in TorchSeg.
63+
64+
## To do
65+
- [ ] Experiments on ADE20K
66+
- [ ] Visualization of tree filter
67+
- [ ] Additional tasks
68+
- [ ] Object detection
69+
- [ ] Instance segmentation
70+
- [ ] Optical flow
71+
72+
## Citation
73+
74+
Please cite the learnable tree filter in your publications if it helps your research.
75+
76+
```
77+
The pre-printed version has been submitted to Arxiv and is awaiting public.
78+
```
79+
80+
Please cite this project in your publications if it helps your research.
81+
```
82+
@misc{treefilter-torch,
83+
author = {Song, Lin},
84+
title = {TreeFiler-Torch},
85+
howpublished = {\url{https://github.com/StevenGrove/TreeFilter-Torch}},
86+
year = {2019}
87+
}
88+
```
89+

demo/introduce.png

640 KB
Loading

furnace/__init__.py

Whitespace-only changes.

furnace/base_model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .resnet import ResNet, resnet18, resnet34, resnet50, resnet101, resnet152
2+
from .xception import Xception, xception39

furnace/base_model/resnet.py

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
import functools
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
from utils.pyt_utils import load_model
6+
7+
8+
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
9+
'resnet152']
10+
11+
12+
def conv3x3(in_planes, out_planes, stride=1):
13+
"""3x3 convolution with padding"""
14+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
15+
padding=1, bias=False)
16+
17+
18+
class BasicBlock(nn.Module):
19+
expansion = 1
20+
21+
def __init__(self, inplanes, planes, stride=1, norm_layer=None,
22+
bn_eps=1e-5, bn_momentum=0.1, downsample=None, inplace=True,
23+
has_relu=True):
24+
super(BasicBlock, self).__init__()
25+
self.conv1 = conv3x3(inplanes, planes, stride)
26+
self.bn1 = norm_layer(planes, eps=bn_eps, momentum=bn_momentum)
27+
self.relu = nn.ReLU(inplace=inplace)
28+
self.relu_inplace = nn.ReLU(inplace=True)
29+
self.conv2 = conv3x3(planes, planes)
30+
self.bn2 = norm_layer(planes, eps=bn_eps, momentum=bn_momentum)
31+
self.downsample = downsample
32+
if downsample is None and inplace != planes:
33+
self.downsample = nn.Sequential(
34+
nn.Conv2d(inplanes, planes,
35+
kernel_size=1, stride=stride, bias=False),
36+
norm_layer(planes, eps=bn_eps,momentum=bn_momentum))
37+
self.stride = stride
38+
self.inplace = inplace
39+
self.has_relu = has_relu
40+
41+
def forward(self, x):
42+
residual = x
43+
44+
out = self.conv1(x)
45+
out = self.bn1(out)
46+
out = self.relu(out)
47+
48+
out = self.conv2(out)
49+
out = self.bn2(out)
50+
51+
if self.downsample is not None:
52+
residual = self.downsample(x)
53+
54+
if self.inplace:
55+
out += residual
56+
else:
57+
out = out + residual
58+
59+
if self.has_relu:
60+
out = self.relu_inplace(out)
61+
62+
return out
63+
64+
class ResBlock(nn.Module):
65+
def __init__(self, in_channels, out_channels, stride=1,
66+
expansion=2, norm_layer=None, bn_eps=1e-5,
67+
bn_momentum=0.1, has_relu=True, has_bias=False):
68+
super(ResBlock, self).__init__()
69+
self.has_relu = has_relu
70+
mid_channels = out_channels // expansion
71+
self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=has_bias)
72+
self.bn1 = norm_layer(mid_channels, eps=bn_eps, momentum=bn_momentum)
73+
self.conv2 = nn.Conv2d(mid_channels, mid_channels, kernel_size=3, stride=stride,
74+
padding=1, bias=has_bias)
75+
self.bn2 = norm_layer(mid_channels, eps=bn_eps, momentum=bn_momentum)
76+
self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=has_bias)
77+
self.bn3 = norm_layer(out_channels, eps=bn_eps, momentum=bn_momentum)
78+
if in_channels != out_channels:
79+
self.down_sampler = nn.Sequential(
80+
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
81+
norm_layer(out_channels, eps=bn_eps,momentum=bn_momentum))
82+
else:
83+
self.down_sampler = None
84+
85+
def forward(self, x):
86+
residual = x
87+
88+
out = self.conv1(x)
89+
out = self.bn1(out)
90+
out = F.relu(out)
91+
92+
out = self.conv2(out)
93+
out = self.bn2(out)
94+
out = F.relu(out)
95+
96+
out = self.conv3(out)
97+
out = self.bn3(out)
98+
99+
if self.down_sampler is not None:
100+
residual = self.down_sampler(x)
101+
102+
out += residual
103+
if self.has_relu:
104+
out = F.relu(out, inplace=True)
105+
106+
return out
107+
108+
class Bottleneck(nn.Module):
109+
expansion = 4
110+
111+
def __init__(self, inplanes, planes, stride=1,
112+
norm_layer=None, bn_eps=1e-5, bn_momentum=0.1,
113+
downsample=None, inplace=True, has_relu=True):
114+
super(Bottleneck, self).__init__()
115+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
116+
self.bn1 = norm_layer(planes, eps=bn_eps, momentum=bn_momentum)
117+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
118+
padding=1, bias=False)
119+
self.bn2 = norm_layer(planes, eps=bn_eps, momentum=bn_momentum)
120+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
121+
bias=False)
122+
self.bn3 = norm_layer(planes * self.expansion, eps=bn_eps,
123+
momentum=bn_momentum)
124+
self.has_relu = has_relu
125+
self.relu = nn.ReLU(inplace=inplace)
126+
self.relu_inplace = nn.ReLU(inplace=True)
127+
self.downsample = downsample
128+
self.stride = stride
129+
self.inplace = inplace
130+
131+
def forward(self, x):
132+
residual = x
133+
134+
out = self.conv1(x)
135+
out = self.bn1(out)
136+
out = self.relu(out)
137+
138+
out = self.conv2(out)
139+
out = self.bn2(out)
140+
out = self.relu(out)
141+
142+
out = self.conv3(out)
143+
out = self.bn3(out)
144+
145+
if self.downsample is not None:
146+
residual = self.downsample(x)
147+
148+
if self.inplace:
149+
out += residual
150+
else:
151+
out = out + residual
152+
if self.has_relu:
153+
out = self.relu_inplace(out)
154+
155+
return out
156+
157+
158+
class ResNet(nn.Module):
159+
160+
def __init__(self, block, layers, norm_layer=nn.BatchNorm2d, bn_eps=1e-5,
161+
bn_momentum=0.1, deep_stem=False, stem_width=32, inplace=True):
162+
self.inplanes = stem_width * 2 if deep_stem else 64
163+
super(ResNet, self).__init__()
164+
if deep_stem:
165+
self.conv1 = nn.Sequential(
166+
nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1,
167+
bias=False),
168+
norm_layer(stem_width, eps=bn_eps, momentum=bn_momentum),
169+
nn.ReLU(inplace=inplace),
170+
nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1,
171+
padding=1,
172+
bias=False),
173+
norm_layer(stem_width, eps=bn_eps, momentum=bn_momentum),
174+
nn.ReLU(inplace=inplace),
175+
nn.Conv2d(stem_width, stem_width * 2, kernel_size=3, stride=1,
176+
padding=1,
177+
bias=False),
178+
)
179+
else:
180+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
181+
bias=False)
182+
183+
self.bn1 = norm_layer(stem_width * 2 if deep_stem else 64, eps=bn_eps,
184+
momentum=bn_momentum)
185+
self.relu = nn.ReLU(inplace=inplace)
186+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
187+
self.layer1 = self._make_layer(block, norm_layer, 64, layers[0],
188+
inplace,
189+
bn_eps=bn_eps, bn_momentum=bn_momentum)
190+
self.layer2 = self._make_layer(block, norm_layer, 128, layers[1],
191+
inplace, stride=2,
192+
bn_eps=bn_eps, bn_momentum=bn_momentum)
193+
self.layer3 = self._make_layer(block, norm_layer, 256, layers[2],
194+
inplace, stride=2,
195+
bn_eps=bn_eps, bn_momentum=bn_momentum)
196+
self.layer4 = self._make_layer(block, norm_layer, 512, layers[3],
197+
inplace, stride=2,
198+
bn_eps=bn_eps, bn_momentum=bn_momentum)
199+
self.layer_channel_nums = (256, 512, 1024, 2048)
200+
201+
def _make_layer(self, block, norm_layer, planes, blocks, inplace=True,
202+
stride=1, bn_eps=1e-5, bn_momentum=0.1):
203+
downsample = None
204+
if stride != 1 or self.inplanes != planes * block.expansion:
205+
downsample = nn.Sequential(
206+
nn.Conv2d(self.inplanes, planes * block.expansion,
207+
kernel_size=1, stride=stride, bias=False),
208+
norm_layer(planes * block.expansion, eps=bn_eps,
209+
momentum=bn_momentum),
210+
)
211+
212+
layers = []
213+
layers.append(block(self.inplanes, planes, stride, norm_layer, bn_eps,
214+
bn_momentum, downsample, inplace))
215+
self.inplanes = planes * block.expansion
216+
for i in range(1, blocks):
217+
layers.append(block(self.inplanes, planes,
218+
norm_layer=norm_layer, bn_eps=bn_eps,
219+
bn_momentum=bn_momentum, inplace=inplace))
220+
221+
return nn.Sequential(*layers)
222+
223+
def forward(self, x):
224+
x = self.conv1(x)
225+
x = self.bn1(x)
226+
x = self.relu(x)
227+
x = self.maxpool(x)
228+
229+
blocks = []
230+
x = self.layer1(x)
231+
blocks.append(x)
232+
x = self.layer2(x)
233+
blocks.append(x)
234+
x = self.layer3(x)
235+
blocks.append(x)
236+
x = self.layer4(x)
237+
blocks.append(x)
238+
239+
return blocks
240+
241+
242+
def resnet18(pretrained_model=None, **kwargs):
243+
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
244+
245+
if pretrained_model is not None:
246+
model = load_model(model, pretrained_model)
247+
return model
248+
249+
250+
def resnet34(pretrained_model=None, **kwargs):
251+
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
252+
253+
if pretrained_model is not None:
254+
model = load_model(model, pretrained_model)
255+
return model
256+
257+
258+
def resnet50(pretrained_model=None, **kwargs):
259+
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
260+
261+
if pretrained_model is not None:
262+
model = load_model(model, pretrained_model)
263+
return model
264+
265+
266+
def resnet101(pretrained_model=None, **kwargs):
267+
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
268+
269+
if pretrained_model is not None:
270+
model = load_model(model, pretrained_model)
271+
return model
272+
273+
274+
def resnet152(pretrained_model=None, **kwargs):
275+
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
276+
277+
if pretrained_model is not None:
278+
model = load_model(model, pretrained_model)
279+
return model

0 commit comments

Comments
 (0)