
Commit 532308e

Refactor: Merge Conv Module and Channel Attention Module (#219)

* Fix: Remove some reimported packages
* Refactor: Merge Conv and DepthwiseConv Module into one
* Refactor: Merge Channel Attention
1 parent 1b17b8f commit 532308e

File tree

6 files changed: +174 −178 lines


sscma/models/base/general.py

Lines changed: 36 additions & 11 deletions
@@ -78,6 +78,7 @@ def __init__(
         conv_layer: Optional[Callable[..., nn.Module]] or Dict or AnyStr = None,
         dilation: int = 1,
         inplace: bool = True,
+        use_depthwise: bool = False,
     ) -> None:
         super().__init__()
         if padding is None:
@@ -86,17 +87,41 @@ def __init__(
             conv_layer = nn.Conv2d
         else:
             conv_layer = get_conv(conv_layer)
-        conv = conv_layer(
-            in_channels,
-            out_channels,
-            kernel_size,
-            stride,
-            padding,
-            dilation=dilation,
-            groups=groups,
-            bias=norm_layer is None if bias is None else bias,
-        )
-        self.add_module('conv', conv)
+        if use_depthwise:
+            dw_conv = conv_layer(
+                in_channels,
+                in_channels,
+                kernel_size,
+                stride,
+                padding,
+                dilation=dilation,
+                groups=in_channels,
+                bias=norm_layer is None if bias is None else bias,
+            )
+            pw_conv = conv_layer(
+                in_channels,
+                out_channels,
+                1,
+                stride,
+                padding,
+                dilation=dilation,
+                groups=1,
+                bias=norm_layer is None if bias is None else bias,
+            )
+            self.add_module('dw_conv', dw_conv)
+            self.add_module('pw_conv', pw_conv)
+        else:
+            conv = conv_layer(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride,
+                padding,
+                dilation=dilation,
+                groups=groups,
+                bias=norm_layer is None if bias is None else bias,
+            )
+            self.add_module('conv', conv)
         if norm_layer is not None:
             norm_layer = get_norm(norm_layer)
             self.add_module('norm', norm_layer(out_channels))
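
For reference, when use_depthwise=True the module above now registers a per-channel spatial filter (dw_conv, with groups=in_channels) followed by a 1x1 channel mixer (pw_conv) in place of a single convolution. Below is a minimal sketch of that factorization in plain PyTorch, assuming stride 1 and "same" padding, and omitting the norm/activation layers that ConvNormActivation appends afterwards; the class name is hypothetical, not part of sscma.

import torch
import torch.nn as nn

class DepthwiseSeparableSketch(nn.Sequential):
    """Hypothetical illustration of the dw_conv + pw_conv pair built above."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super().__init__(
            # Depthwise stage: one spatial filter per input channel (groups == in_channels).
            nn.Conv2d(in_channels, in_channels, kernel_size,
                      padding=kernel_size // 2, groups=in_channels, bias=False),
            # Pointwise stage: a 1x1 convolution mixes channels to the output width.
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
        )

x = torch.randn(1, 16, 32, 32)
print(DepthwiseSeparableSketch(16, 32)(x).shape)  # torch.Size([1, 32, 32, 32])

The factorized pair costs roughly kernel_size**2 * in_channels + in_channels * out_channels weights versus kernel_size**2 * in_channels * out_channels for the fused convolution, which is the usual motivation for the depthwise option.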

sscma/models/detectors/base.py

Lines changed: 9 additions & 9 deletions
@@ -1,25 +1,25 @@
 # Copyright (c) Seeed Technology Co.,Ltd. All rights reserved.
-from typing import Dict, Optional, Union, Tuple, List
-from abc import ABCMeta, abstractmethod
 import copy
+from abc import ABCMeta, abstractmethod
+from typing import Dict, List, Optional, Tuple, Union
 
+import torch
 from mmdet.models.detectors import BaseDetector, SemiBaseDetector
-from mmdet.structures import DetDataSample, OptSampleList, SampleList
-from mmdet.utils import OptConfigType, OptMultiConfig, ConfigType, InstanceList
 from mmdet.models.utils import rename_loss_dict, reweight_loss_dict
-from mmdet.structures import SampleList
+from mmdet.structures import DetDataSample, OptSampleList, SampleList
+from mmdet.utils import ConfigType, InstanceList, OptConfigType, OptMultiConfig
 from mmengine.model import BaseModel
-import torch
-from torch import Tensor
 from mmengine.optim import OptimWrapper
-from ..utils import samplelist_boxtype2tensor
+from torch import Tensor
 
-from sscma.registry import MODELS
 from sscma.models.semi import BasePseudoLabelCreator
+from sscma.registry import MODELS
 
+from ..utils import samplelist_boxtype2tensor
 
 ForwardResults = Union[Dict[str, torch.Tensor], List[DetDataSample], Tuple[torch.Tensor], torch.Tensor]
 
+
 @MODELS.register_module()
 class BaseSsod(SemiBaseDetector):
     teacher: BaseDetector

sscma/models/layers/csp_layer.py

Lines changed: 95 additions & 109 deletions
@@ -2,81 +2,60 @@
 # Copyright (c) OpenMMLab.
 import torch
 import torch.nn as nn
-from sscma.models.base import ConvModule, DepthwiseSeparableConvModule
+from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
 from mmengine.model import BaseModule
 from torch import Tensor
 
-from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
+from sscma.models.base import ConvNormActivation
 
-
-class ChannelAttention(BaseModule):
-    def __init__(self, channels: int, init_cfg: OptMultiConfig = None) -> None:
-        super().__init__(init_cfg=init_cfg)
-        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
-        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
-        self.act = nn.Hardsigmoid(inplace=True)
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Forward function for ChannelAttention."""
-        with torch.cuda.amp.autocast(enabled=False):
-            out = self.global_avgpool(x)
-            out = self.fc(out)
-            out = self.act(out)
-        return x * out
+from .attention import ChannelAttention
 
 
 class CSPLayer(BaseModule):
-    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 expand_ratio: float = 0.5,
-                 num_blocks: int = 1,
-                 add_identity: bool = True,
-                 use_depthwise: bool = False,
-                 use_cspnext_block: bool = False,
-                 channel_attention: bool = False,
-                 conv_cfg: OptConfigType = None,
-                 norm_cfg: ConfigType = dict(
-                     type='BN', momentum=0.03, eps=0.001),
-                 act_cfg: ConfigType = dict(type='Swish'),
-                 init_cfg: OptMultiConfig = None) -> None:
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        expand_ratio: float = 0.5,
+        num_blocks: int = 1,
+        add_identity: bool = True,
+        use_depthwise: bool = False,
+        use_cspnext_block: bool = False,
+        channel_attention: bool = False,
+        conv_cfg: OptConfigType = None,
+        norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
+        act_cfg: ConfigType = dict(type='Swish'),
+        init_cfg: OptMultiConfig = None,
+    ) -> None:
         super().__init__(init_cfg=init_cfg)
         block = CSPNeXtBlock if use_cspnext_block else DarknetBottleneck
         mid_channels = int(out_channels * expand_ratio)
         self.channel_attention = channel_attention
-        self.main_conv = ConvModule(
-            in_channels,
-            mid_channels,
-            1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg)
-        self.short_conv = ConvModule(
-            in_channels,
-            mid_channels,
-            1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg)
-        self.final_conv = ConvModule(
-            2 * mid_channels,
-            out_channels,
-            1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg)
-
-        self.blocks = nn.Sequential(*[
-            block(
-                mid_channels,
-                mid_channels,
-                1.0,
-                add_identity,
-                use_depthwise,
-                conv_cfg=conv_cfg,
-                norm_cfg=norm_cfg,
-                act_cfg=act_cfg) for _ in range(num_blocks)
-        ])
+        self.main_conv = ConvNormActivation(
+            in_channels, mid_channels, 1, conv_layer=conv_cfg, norm_layer=norm_cfg, activation_layer=act_cfg
+        )
+        self.short_conv = ConvNormActivation(
+            in_channels, mid_channels, 1, conv_layer=conv_cfg, norm_layer=norm_cfg, activation_layer=act_cfg
+        )
+        self.final_conv = ConvNormActivation(
+            2 * mid_channels, out_channels, 1, conv_layer=conv_cfg, norm_layer=norm_cfg, activation_layer=act_cfg
+        )
+
+        self.blocks = nn.Sequential(
+            *[
+                block(
+                    mid_channels,
+                    mid_channels,
+                    1.0,
+                    add_identity,
+                    use_depthwise,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg,
+                )
+                for _ in range(num_blocks)
+            ]
+        )
         if channel_attention:
             self.attention = ChannelAttention(2 * mid_channels)
 
@@ -92,40 +71,44 @@ def forward(self, x: Tensor) -> Tensor:
         if self.channel_attention:
             x_final = self.attention(x_final)
         return self.final_conv(x_final)
-
+
+
 class DarknetBottleneck(BaseModule):
-    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 expansion: float = 0.5,
-                 add_identity: bool = True,
-                 use_depthwise: bool = False,
-                 conv_cfg: OptConfigType = None,
-                 norm_cfg: ConfigType = dict(
-                     type='BN', momentum=0.03, eps=0.001),
-                 act_cfg: ConfigType = dict(type='Swish'),
-                 init_cfg: OptMultiConfig = None) -> None:
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        expansion: float = 0.5,
+        add_identity: bool = True,
+        use_depthwise: bool = False,
+        conv_cfg: OptConfigType = None,
+        norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
+        act_cfg: ConfigType = dict(type='Swish'),
+        init_cfg: OptMultiConfig = None,
+    ) -> None:
         super().__init__(init_cfg=init_cfg)
         hidden_channels = int(out_channels * expansion)
-        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
-        self.conv1 = ConvModule(
+        self.conv1 = ConvNormActivation(
             in_channels,
             hidden_channels,
             1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg)
-        self.conv2 = conv(
+            conv_layer=conv_cfg,
+            norm_layer=norm_cfg,
+            activation_layer=act_cfg,
+            use_depthwise=False,
+        )
+        self.conv2 = ConvNormActivation(
             hidden_channels,
             out_channels,
             3,
             stride=1,
             padding=1,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg)
-        self.add_identity = \
-            add_identity and in_channels == out_channels
+            conv_layer=conv_cfg,
+            norm_layer=norm_cfg,
+            activation_layer=act_cfg,
+            use_depthwise=use_depthwise,
+        )
+        self.add_identity = add_identity and in_channels == out_channels
 
     def forward(self, x: Tensor) -> Tensor:
         """Forward function."""
@@ -140,40 +123,43 @@ def forward(self, x: Tensor) -> Tensor:
 
 
 class CSPNeXtBlock(BaseModule):
-    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 expansion: float = 0.5,
-                 add_identity: bool = True,
-                 use_depthwise: bool = False,
-                 kernel_size: int = 5,
-                 conv_cfg: OptConfigType = None,
-                 norm_cfg: ConfigType = dict(
-                     type='BN', momentum=0.03, eps=0.001),
-                 act_cfg: ConfigType = dict(type='SiLU'),
-                 init_cfg: OptMultiConfig = None) -> None:
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        expansion: float = 0.5,
+        add_identity: bool = True,
+        use_depthwise: bool = False,
+        kernel_size: int = 5,
+        conv_cfg: OptConfigType = None,
+        norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
+        act_cfg: ConfigType = dict(type='SiLU'),
+        init_cfg: OptMultiConfig = None,
+    ) -> None:
         super().__init__(init_cfg=init_cfg)
         hidden_channels = int(out_channels * expansion)
-        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule
-        self.conv1 = conv(
+        self.conv1 = ConvNormActivation(
             in_channels,
             hidden_channels,
             3,
             stride=1,
             padding=1,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg)
-        self.conv2 = DepthwiseSeparableConvModule(
+            norm_layer=norm_cfg,
+            activation_layer=act_cfg,
+            use_depthwise=use_depthwise,
+        )
+        self.conv2 = ConvNormActivation(
             hidden_channels,
            out_channels,
             kernel_size,
             stride=1,
             padding=kernel_size // 2,
-            conv_cfg=conv_cfg,
-            norm_cfg=norm_cfg,
-            act_cfg=act_cfg)
-        self.add_identity = \
-            add_identity and in_channels == out_channels
+            conv_layer=conv_cfg,
+            norm_layer=norm_cfg,
+            activation_layer=act_cfg,
+            use_depthwise=True,
+        )
+        self.add_identity = add_identity and in_channels == out_channels
 
     def forward(self, x: Tensor) -> Tensor:
         """Forward function."""
