Commit 09c03eb

committed
support LightCDNet
1 parent 6ef61e4 commit 09c03eb

File tree

11 files changed: +459 additions, -13 deletions

README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -37,6 +37,7 @@ Supported change detection model:
 - [x] [Changer (TGRS'2023)](configs/changer)
 - [x] [HANet (JSTARS'2023)](configs/hanet)
 - [x] [TinyCDv2 (Under Review)](configs/tinycd_v2)
+- [x] [LightCDNet (GRSL'2023)](configs/lightcdnet)
 - [x] [BAN (arXiv'2023)](configs/ban)
 - [x] [TTP (arXiv'2023)](configs/ttp)
 - [ ] ...
```
Lines changed: 54 additions & 0 deletions

```python
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
    type='DualInputSegDataPreProcessor',
    mean=[123.675, 116.28, 103.53] * 2,
    std=[58.395, 57.12, 57.375] * 2,
    bgr_to_rgb=True,
    size_divisor=32,
    pad_val=0,
    seg_pad_val=255,
    test_cfg=dict(size_divisor=32))
model = dict(
    type='DIEncoderDecoder',
    data_preprocessor=data_preprocessor,
    pretrained=None,
    backbone=dict(
        type='LightCDNet',
        stage_repeat_num=[4, 8, 4],
        net_type="small"),
    neck=dict(
        type='TinyFPN',
        exist_early_x=True,
        early_x_for_fpn=True,
        custom_block='conv',
        in_channels=[24, 48, 96, 192],
        out_channels=48,
        num_outs=4),
    decode_head=dict(
        type='DS_FPNHead',
        in_channels=[48, 48, 48, 48],
        in_index=[0, 1, 2, 3],
        channels=48,
        dropout_ratio=0.,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='mmseg.CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='mmseg.FCNHead',
        in_channels=24,
        in_index=0,
        channels=24,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.,
        num_classes=2,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='mmseg.CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
```

configs/lightcdnet/README.md

Lines changed: 45 additions & 0 deletions

# LightCDNet

[LightCDNet: Lightweight Change Detection Network Based on VHR Images](https://ieeexplore.ieee.org/document/10214556)

## Introduction

[Official Repo](https://github.com/NightSongs/LightCDNet)

[Code Snippet](https://github.com/likyoo/open-cd/blob/main/opencd/models/backbones/lightcdnet.py)

## Abstract

Lightweight change detection models are essential for industrial applications and edge devices. Reducing the model size while maintaining high accuracy is a key challenge in developing lightweight change detection models. However, many existing methods oversimplify the model architecture, leading to a loss of information and reduced performance. Therefore, developing a lightweight model that can effectively preserve the input information is a challenging problem. To address this challenge, we propose LightCDNet, a novel lightweight change detection model that effectively preserves the input information. LightCDNet consists of an early fusion backbone network and a pyramid decoder for end-to-end change detection. The core component of LightCDNet is the Deep Supervised Fusion Module (DSFM), which guides the early fusion of primary features to improve performance. We evaluated LightCDNet on the LEVIR-CD dataset and found that it achieved comparable or better performance than state-of-the-art models while being 10–117 times smaller in size.

<!-- [IMAGE] -->

<div align=center>
<img src="https://github.com/likyoo/open-cd/assets/44317497/cec088ca-cb45-4d32-8ebb-c0fd3b8d1a4c" width="90%"/>
</div>

```bibtex
@ARTICLE{10214556,
  author={Xing, Yuanjun and Jiang, Jiawei and Xiang, Jun and Yan, Enping and Song, Yabin and Mo, Dengkui},
  journal={IEEE Geoscience and Remote Sensing Letters},
  title={LightCDNet: Lightweight Change Detection Network Based on VHR Images},
  year={2023},
  volume={20},
  number={},
  pages={1-5},
  doi={10.1109/LGRS.2023.3304309}}
```

## Results and models

### LEVIR-CD

| Method | Crop Size | Lr schd | \#Param (M) | MACs (G) | Precision | Recall | F1-Score | IoU | config |
| :--------------: | :-------: | :-----: | :---------: | :------: | :-------: | :----: | :------: | :---: | :----: |
| LightCDNet-small | 256x256 | 40000 | 0.35 | 1.65 | 91.36 | 89.81 | 90.57 | 82.77 | [config](https://github.com/likyoo/open-cd/blob/main/configs/lightcdnet/lightcdnet_s_256x256_40k_levircd.py) |
| LightCDNet-base | 256x256 | 40000 | 1.32 | 3.22 | 92.12 | 90.43 | 91.27 | 83.94 | [config](https://github.com/likyoo/open-cd/blob/main/configs/lightcdnet/lightcdnet_b_256x256_40k_levircd.py) |
| LightCDNet-large | 256x256 | 40000 | 2.82 | 5.94 | 92.43 | 90.45 | 91.43 | 84.21 | [config](https://github.com/likyoo/open-cd/blob/main/configs/lightcdnet/lightcdnet_l_256x256_40k_levircd.py) |

- All metrics are based on the category "change".
- All scores are computed on the test set.
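As a quick sanity check, the F1 and IoU columns follow directly from precision and recall: F1 is their harmonic mean, and for a single binary class IoU = F1 / (2 - F1). A minimal sketch (the `f1_iou` helper is illustrative, not part of the repo); small last-digit differences from the table are expected, since the published numbers are computed from raw pixel counts rather than the rounded percentages:

```python
def f1_iou(precision, recall):
    """Derive F1 and IoU (both in %) from precision/recall percentages."""
    p, r = precision / 100, recall / 100
    f1 = 2 * p * r / (p + r)   # harmonic mean of precision and recall
    iou = f1 / (2 - f1)        # binary-class identity linking IoU and F1
    return round(f1 * 100, 2), round(iou * 100, 2)

# LightCDNet-small row: Precision 91.36, Recall 89.81
f1, iou = f1_iou(91.36, 89.81)
print(f1, iou)  # close to the tabulated 90.57 / 82.77
```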
Lines changed: 5 additions & 0 deletions

```python
_base_ = ['./lightcdnet_s_256x256_40k_levircd.py']

model = dict(
    backbone=dict(net_type="base"),
    neck=dict(in_channels=[24, 116, 232, 464]))
```
Lines changed: 5 additions & 0 deletions

```python
_base_ = ['./lightcdnet_s_256x256_40k_levircd.py']

model = dict(
    backbone=dict(net_type="large"),
    neck=dict(in_channels=[24, 176, 352, 704]))
```
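The base and large variant configs above override only `net_type` and the neck's `in_channels` on top of the small config via `_base_` inheritance. A self-contained sketch of how such nested-dict overrides merge (the `merge_cfg` helper is a hypothetical stand-in that mimics mmengine's Config merging in spirit, not the real implementation):

```python
def merge_cfg(base, override):
    """Recursively merge `override` into a copy of `base` (dicts deep-merge,
    everything else is replaced), roughly how `_base_` inheritance behaves."""
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge_cfg(out[key], value)
        else:
            out[key] = value
    return out

# fields from the small base config
base_model = dict(
    backbone=dict(type='LightCDNet', stage_repeat_num=[4, 8, 4],
                  net_type='small'),
    neck=dict(type='TinyFPN', in_channels=[24, 48, 96, 192], out_channels=48))

# the "base" variant's override, as in the config above
override = dict(
    backbone=dict(net_type='base'),
    neck=dict(in_channels=[24, 116, 232, 464]))

cfg = merge_cfg(base_model, override)
print(cfg['backbone']['net_type'])         # 'base'
print(cfg['neck']['in_channels'])          # [24, 116, 232, 464]
print(cfg['backbone']['stage_repeat_num']) # [4, 8, 4]  (inherited unchanged)
```

Untouched keys such as `stage_repeat_num` and `out_channels` are inherited from the base; only the listed leaves change.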
Lines changed: 19 additions & 0 deletions

```python
_base_ = [
    '../_base_/models/lightcdnet.py',
    '../common/standard_256x256_40k_levircd.py']

model = dict(
    decode_head=dict(
        sampler=dict(type='mmseg.OHEMPixelSampler', thresh=0.7, min_kept=100000)))

# optimizer
optimizer = dict(
    type='AdamW',
    lr=0.003,
    betas=(0.9, 0.999),
    weight_decay=0.05)

optim_wrapper = dict(
    _delete_=True,
    type='OptimWrapper',
    optimizer=optimizer)
```

opencd/models/backbones/__init__.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -9,8 +9,10 @@
 from .hanet import HAN
 from .vit_tuner import VisionTransformerTurner
 from .vit_sam import ViTSAM_Custom
+from .lightcdnet import LightCDNet

 __all__ = ['IA_ResNetV1c', 'IA_ResNeSt', 'FC_EF', 'FC_Siam_diff',
            'FC_Siam_conc', 'SNUNet_ECAM', 'TinyCD', 'IFN',
            'TinyNet', 'IA_MixVisionTransformer', 'HAN',
-           'VisionTransformerTurner', 'ViTSAM_Custom']
+           'VisionTransformerTurner', 'ViTSAM_Custom',
+           'LightCDNet']
```
Lines changed: 227 additions & 0 deletions

```python
# Copyright (c) Open-CD. All rights reserved.
import torch
import torch.nn as nn
import numpy as np
from mmcv.ops import CrissCrossAttention

from mmseg.models.utils import LayerNorm2d
from opencd.registry import MODELS


class CCA(nn.Module):
    """Criss-Cross Attention for Semantic Segmentation.

    This module is the implementation of `CCNet
    <https://arxiv.org/abs/1811.11721>`_.

    Args:
        channels (int): Number of input channels.
        recurrence (int): Number of recurrences of the Criss-Cross
            Attention module. Default: 2.
    """

    def __init__(self, channels, recurrence=2):
        super(CCA, self).__init__()
        self.recurrence = recurrence
        self.cca = CrissCrossAttention(channels)

    def forward(self, x):
        for _ in range(self.recurrence):
            x = self.cca(x)
        return x


def channel_shuffle(x, groups=2):
    """ShuffleNet-style channel shuffle: interleave channels across groups."""
    bat_size, channels, w, h = x.shape
    group_c = channels // groups
    x = x.view(bat_size, groups, group_c, w, h)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(bat_size, -1, w, h)
    return x


class ShuffleBlock(nn.Module):

    def __init__(self, in_c, out_c, downsample=False):
        super(ShuffleBlock, self).__init__()
        self.downsample = downsample
        half_c = out_c // 2
        if downsample:
            self.branch1 = nn.Sequential(
                # 3x3 dw conv, stride = 2
                nn.Conv2d(in_c, in_c, 3, 2, 1, groups=in_c, bias=False),
                nn.BatchNorm2d(in_c),
                # 1x1 pw conv
                nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False),
                nn.BatchNorm2d(half_c),
                nn.ReLU(True))

            self.branch2 = nn.Sequential(
                # 1x1 pw conv
                nn.Conv2d(in_c, half_c, 1, 1, 0, bias=False),
                nn.BatchNorm2d(half_c),
                nn.ReLU(True),
                # 3x3 dw conv, stride = 2
                nn.Conv2d(half_c, half_c, 3, 2, 1, groups=half_c, bias=False),
                nn.BatchNorm2d(half_c),
                # 1x1 pw conv
                nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False),
                nn.BatchNorm2d(half_c),
                nn.ReLU(True))
        else:
            assert in_c == out_c

            self.branch2 = nn.Sequential(
                # 1x1 pw conv
                nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False),
                nn.BatchNorm2d(half_c),
                nn.ReLU(True),
                # 3x3 dw conv, stride = 1
                nn.Conv2d(half_c, half_c, 3, 1, 1, groups=half_c, bias=False),
                nn.BatchNorm2d(half_c),
                # 1x1 pw conv
                nn.Conv2d(half_c, half_c, 1, 1, 0, bias=False),
                nn.BatchNorm2d(half_c),
                nn.ReLU(True))

    def forward(self, x):
        if self.downsample:
            # when downsampling, no channel split is needed
            out = torch.cat((self.branch1(x), self.branch2(x)), 1)
        else:
            # channel split
            channels = x.shape[1]
            c = channels // 2
            x1 = x[:, :c, :, :]
            x2 = x[:, c:, :, :]
            out = torch.cat((x1, self.branch2(x2)), 1)

        return channel_shuffle(out, 2)


class TimeAttention(nn.Module):

    def __init__(self, channels):
        super(TimeAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        attn_channels = channels // 16
        attn_channels = max(attn_channels, 8)
        self.mlp = nn.Sequential(
            nn.Conv2d(channels * 2, attn_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(attn_channels),
            nn.ReLU(),
            nn.Conv2d(attn_channels, channels * 2, kernel_size=1, bias=False),
        )

    def forward(self, x1, x2):
        x = torch.cat((x1, x2), dim=1)
        x = self.avg_pool(x)
        y = self.mlp(x)
        B, C, H, W = y.size()
        x1_attn, x2_attn = y.reshape(B, 2, C // 2, H, W).transpose(0, 1)
        x1_attn = torch.sigmoid(x1_attn)
        x2_attn = torch.sigmoid(x2_attn)
        x1 = x1 * x1_attn + x1
        x2 = x2 * x2_attn + x2
        return x1, x2


class shuffle_fusion(nn.Module):

    def __init__(self, channels, block_num=2):
        super().__init__()

        self.stages = []
        self.stages.append(
            nn.Sequential(
                nn.Conv2d(channels, channels * 4, kernel_size=1, bias=False),
                nn.BatchNorm2d(channels * 4), nn.ReLU()))
        for i in range(block_num):
            self.stages.append(
                ShuffleBlock(channels * 4, channels * 4, downsample=False))

        self.stages = nn.Sequential(*self.stages)

        self.single_conv = nn.Sequential(
            nn.Conv2d(channels * 4, channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(channels), nn.ReLU())

        self.time_attn = TimeAttention(channels)

        self.final_conv = nn.Sequential(
            nn.Conv2d(channels * 2, channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(channels), nn.ReLU())

    def forward_single(self, x):
        identity = x
        x = self.stages(x)
        x = self.single_conv(x)
        x = identity + x
        return x

    def forward(self, x1, x2):
        x1 = self.forward_single(x1)
        x2 = self.forward_single(x2)
        x1, x2 = self.time_attn(x1, x2)
        x = self.final_conv(channel_shuffle(torch.cat((x1, x2), dim=1)))
        return x


@MODELS.register_module()
class LightCDNet(nn.Module):

    def __init__(self, stage_repeat_num, net_type="small"):
        super(LightCDNet, self).__init__()

        index_list = stage_repeat_num.copy()
        index_list[0] = index_list[0] - 1
        self.index_list = list(np.cumsum(index_list))
        if net_type == "small":
            self.out_channels = [24, 48, 96, 192]
            self.block_num = 4
        elif net_type == "base":
            self.out_channels = [24, 116, 232, 464]
            self.block_num = 8
        elif net_type == "large":
            self.out_channels = [24, 176, 352, 704]
            self.block_num = 16
        else:
            raise ValueError(
                f'Invalid net_type {net_type!r}, '
                "expected 'small', 'base' or 'large'")

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, self.out_channels[0], 3, 2, 1, bias=False),
            LayerNorm2d(self.out_channels[0]), nn.GELU())

        self.fusion_conv = shuffle_fusion(
            self.out_channels[0], block_num=self.block_num)

        in_c = self.out_channels[0]

        self.stages = []
        for stage_idx in range(len(stage_repeat_num)):
            out_c = self.out_channels[1 + stage_idx]
            repeat_num = stage_repeat_num[stage_idx]
            for i in range(repeat_num):
                if i == 0:
                    self.stages.append(
                        ShuffleBlock(in_c, out_c, downsample=True))
                else:
                    self.stages.append(
                        ShuffleBlock(in_c, in_c, downsample=False))
                in_c = out_c
            self.stages.append(CCA(channels=out_c, recurrence=2))

        self.stages = nn.Sequential(*self.stages)

    def forward(self, x1, x2):
        x1 = self.conv1(x1)
        x2 = self.conv1(x2)
        x = self.fusion_conv(x1, x2)
        outs = [x]

        for i in range(len(self.stages)):
            x = self.stages[i](x)
            if i in self.index_list:
                outs.append(x)
        return outs
```
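The `channel_shuffle` helper above is the standard ShuffleNet-style interleave that lets information cross between the two split branches of each `ShuffleBlock`. A NumPy re-implementation (an illustrative sketch, not part of this commit) makes the channel permutation easy to inspect on a tiny tensor:

```python
import numpy as np

def channel_shuffle_np(x, groups=2):
    # (N, C, H, W) -> reshape to (N, groups, C//groups, H, W),
    # swap the two group axes, and flatten back: channels interleave
    n, c, h, w = x.shape
    x = x.reshape(n, groups, c // groups, h, w)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(n, -1, h, w)

x = np.arange(4).reshape(1, 4, 1, 1)  # channels [0, 1, 2, 3]
y = channel_shuffle_np(x, groups=2)
print(y.flatten().tolist())  # -> [0, 2, 1, 3]
```

With `groups=2` the first half of the channels ends up interleaved with the second half, which is exactly what the torch version's `view`/`transpose`/`view` sequence does.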

opencd/models/decode_heads/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -8,7 +8,8 @@
 from .ban_head import BitemporalAdapterHead
 from .ban_utils import BAN_MLPDecoder, BAN_BITHead
 from .mlpseg_head import MLPSegHead
+from .ds_fpn_head import DS_FPNHead

 __all__ = ['BITHead', 'Changer', 'IdentityHead', 'DSIdentityHead', 'TinyHead',
            'STAHead', 'MultiHeadDecoder', 'GeneralSCDHead', 'BitemporalAdapterHead',
-           'BAN_MLPDecoder', 'BAN_BITHead', 'MLPSegHead']
+           'BAN_MLPDecoder', 'BAN_BITHead', 'MLPSegHead', 'DS_FPNHead']
```
