22# Copyright (c) OpenMMLab.
33import torch
44import torch .nn as nn
5- from sscma . models . base import ConvModule , DepthwiseSeparableConvModule
5+ from mmdet . utils import ConfigType , OptConfigType , OptMultiConfig
66from mmengine .model import BaseModule
77from torch import Tensor
88
9- from mmdet . utils import ConfigType , OptConfigType , OptMultiConfig
9+ from sscma . models . base import ConvNormActivation
1010
11-
class ChannelAttention(BaseModule):
    """Channel attention module.

    Squeezes spatial information with global average pooling, builds a
    per-channel gate via a 1x1 convolution followed by Hardsigmoid, and
    rescales the input channel-wise.

    Args:
        channels (int): Number of input (and output) channels.
        init_cfg (dict or list[dict], optional): Initialization config.
            Defaults to None.
    """

    def __init__(self, channels: int, init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Hardsigmoid(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        """Forward function for ChannelAttention."""
        # Run the gating path with autocast disabled (fp32) for numerical
        # stability under mixed precision.
        # NOTE(review): source formatting is ambiguous about whether fc/act
        # sit inside this context; confirm against upstream.
        with torch.cuda.amp.autocast(enabled=False):
            attn = self.global_avgpool(x)
            attn = self.fc(attn)
            attn = self.act(attn)
            return x * attn
11+ from .attention import ChannelAttention
2612
2713
class CSPLayer(BaseModule):
    """Cross Stage Partial Layer.

    The input is projected by two parallel 1x1 convolutions. One branch is
    processed by ``num_blocks`` bottleneck blocks, then both branches are
    concatenated and fused by a final 1x1 convolution. Channel attention may
    optionally be applied to the concatenated features before fusion.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        expand_ratio (float): Ratio of hidden channels to ``out_channels``.
            Defaults to 0.5.
        num_blocks (int): Number of stacked bottleneck blocks. Defaults to 1.
        add_identity (bool): Whether blocks use a residual connection.
            Defaults to True.
        use_depthwise (bool): Whether blocks use depthwise separable convs.
            Defaults to False.
        use_cspnext_block (bool): Use :class:`CSPNeXtBlock` instead of
            :class:`DarknetBottleneck`. Defaults to False.
        channel_attention (bool): Apply channel attention before the final
            conv. Defaults to False.
        conv_cfg (dict, optional): Config for the convolution layer.
            Defaults to None.
        norm_cfg (dict): Config for the normalization layer.
            Defaults to ``dict(type='BN', momentum=0.03, eps=0.001)``.
        act_cfg (dict): Config for the activation layer.
            Defaults to ``dict(type='Swish')``.
        init_cfg (dict or list[dict], optional): Initialization config.
            Defaults to None.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expand_ratio: float = 0.5,
        num_blocks: int = 1,
        add_identity: bool = True,
        use_depthwise: bool = False,
        use_cspnext_block: bool = False,
        channel_attention: bool = False,
        conv_cfg: OptConfigType = None,
        norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg: ConfigType = dict(type='Swish'),
        init_cfg: OptMultiConfig = None,
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        block = CSPNeXtBlock if use_cspnext_block else DarknetBottleneck
        mid_channels = int(out_channels * expand_ratio)
        self.channel_attention = channel_attention
        # Two parallel 1x1 projections: the "main" branch feeds the
        # bottleneck stack, the "short" branch is the CSP shortcut.
        self.main_conv = ConvNormActivation(
            in_channels,
            mid_channels,
            1,
            conv_layer=conv_cfg,
            norm_layer=norm_cfg,
            activation_layer=act_cfg,
        )
        self.short_conv = ConvNormActivation(
            in_channels,
            mid_channels,
            1,
            conv_layer=conv_cfg,
            norm_layer=norm_cfg,
            activation_layer=act_cfg,
        )
        # Fuses the concatenation of both branches back to out_channels.
        self.final_conv = ConvNormActivation(
            2 * mid_channels,
            out_channels,
            1,
            conv_layer=conv_cfg,
            norm_layer=norm_cfg,
            activation_layer=act_cfg,
        )

        self.blocks = nn.Sequential(
            *(
                block(
                    mid_channels,
                    mid_channels,
                    1.0,
                    add_identity,
                    use_depthwise,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                )
                for _ in range(num_blocks)
            )
        )
        if channel_attention:
            self.attention = ChannelAttention(2 * mid_channels)
8261
@@ -92,40 +71,44 @@ def forward(self, x: Tensor) -> Tensor:
9271 if self .channel_attention :
9372 x_final = self .attention (x_final )
9473 return self .final_conv (x_final )
95-
74+
75+
class DarknetBottleneck(BaseModule):
    """The basic bottleneck block used in Darknet.

    A 1x1 convolution shrinks the channels by ``expansion``, a 3x3
    convolution restores them, and the input is added back when
    ``add_identity`` is set and the channel counts match.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        expansion (float): Ratio of hidden channels to ``out_channels``.
            Defaults to 0.5.
        add_identity (bool): Whether to add a residual connection. Only
            effective when ``in_channels == out_channels``. Defaults to True.
        use_depthwise (bool): Whether the 3x3 conv is depthwise separable.
            Defaults to False.
        conv_cfg (dict, optional): Config for the convolution layer.
            Defaults to None.
        norm_cfg (dict): Config for the normalization layer.
            Defaults to ``dict(type='BN', momentum=0.03, eps=0.001)``.
        act_cfg (dict): Config for the activation layer.
            Defaults to ``dict(type='Swish')``.
        init_cfg (dict or list[dict], optional): Initialization config.
            Defaults to None.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expansion: float = 0.5,
        add_identity: bool = True,
        use_depthwise: bool = False,
        conv_cfg: OptConfigType = None,
        norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg: ConfigType = dict(type='Swish'),
        init_cfg: OptMultiConfig = None,
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        hidden_channels = int(out_channels * expansion)
        # The 1x1 reduction conv is always a plain convolution.
        self.conv1 = ConvNormActivation(
            in_channels,
            hidden_channels,
            1,
            conv_layer=conv_cfg,
            norm_layer=norm_cfg,
            activation_layer=act_cfg,
            use_depthwise=False,
        )
        # The 3x3 expansion conv honours the ``use_depthwise`` flag.
        self.conv2 = ConvNormActivation(
            hidden_channels,
            out_channels,
            3,
            stride=1,
            padding=1,
            conv_layer=conv_cfg,
            norm_layer=norm_cfg,
            activation_layer=act_cfg,
            use_depthwise=use_depthwise,
        )
        self.add_identity = add_identity and in_channels == out_channels
129112
130113 def forward (self , x : Tensor ) -> Tensor :
131114 """Forward function."""
@@ -140,40 +123,43 @@ def forward(self, x: Tensor) -> Tensor:
140123
141124
class CSPNeXtBlock(BaseModule):
    """The basic bottleneck block used in CSPNeXt.

    A 3x3 convolution (optionally depthwise separable) followed by a large
    ``kernel_size`` depthwise-separable convolution, with an optional
    residual connection when the channel counts match.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        expansion (float): Ratio of hidden channels to ``out_channels``.
            Defaults to 0.5.
        add_identity (bool): Whether to add a residual connection. Only
            effective when ``in_channels == out_channels``. Defaults to True.
        use_depthwise (bool): Whether the first conv is depthwise separable.
            Defaults to False.
        kernel_size (int): Kernel size of the second (depthwise) conv.
            Defaults to 5.
        conv_cfg (dict, optional): Config for the convolution layer.
            Defaults to None.
        norm_cfg (dict): Config for the normalization layer.
            Defaults to ``dict(type='BN', momentum=0.03, eps=0.001)``.
        act_cfg (dict): Config for the activation layer.
            Defaults to ``dict(type='SiLU')``.
        init_cfg (dict or list[dict], optional): Initialization config.
            Defaults to None.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expansion: float = 0.5,
        add_identity: bool = True,
        use_depthwise: bool = False,
        kernel_size: int = 5,
        conv_cfg: OptConfigType = None,
        norm_cfg: ConfigType = dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg: ConfigType = dict(type='SiLU'),
        init_cfg: OptMultiConfig = None,
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        hidden_channels = int(out_channels * expansion)
        # NOTE(review): conv1 does not receive ``conv_layer=conv_cfg`` —
        # this mirrors the original code; confirm it is intentional.
        self.conv1 = ConvNormActivation(
            in_channels,
            hidden_channels,
            3,
            stride=1,
            padding=1,
            norm_layer=norm_cfg,
            activation_layer=act_cfg,
            use_depthwise=use_depthwise,
        )
        # The second conv is always depthwise separable, with a large kernel.
        self.conv2 = ConvNormActivation(
            hidden_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=kernel_size // 2,
            conv_layer=conv_cfg,
            norm_layer=norm_cfg,
            activation_layer=act_cfg,
            use_depthwise=True,
        )
        self.add_identity = add_identity and in_channels == out_channels
177163
178164 def forward (self , x : Tensor ) -> Tensor :
179165 """Forward function."""
0 commit comments