import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, calculate_drop_path_rates, Mlp, GlobalResponseNormMlp, \
-    LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple
-from timm.layers import SimpleNorm2d, SimpleNorm
-from timm.layers import NormMlpClassifierHead, ClassifierHead
+from timm.layers import (
+    trunc_normal_,
+    AvgPool2dSame,
+    DropPath,
+    calculate_drop_path_rates,
+    Mlp,
+    GlobalResponseNormMlp,
+    LayerNorm2d,
+    LayerNorm,
+    RmsNorm2d,
+    RmsNorm,
+    SimpleNorm2d,
+    SimpleNorm,
+    create_conv2d,
+    get_act_layer,
+    get_norm_layer,
+    make_divisible,
+    to_ntuple,
+    NormMlpClassifierHead,
+    ClassifierHead,
+)
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq
class Downsample(nn.Module):
    """Downsample module for ConvNeXt."""

-    def __init__(self, in_chs: int, out_chs: int, stride: int = 1, dilation: int = 1) -> None:
+    def __init__(
+            self,
+            in_chs: int,
+            out_chs: int,
+            stride: int = 1,
+            dilation: int = 1,
+            device=None,
+            dtype=None,
+    ) -> None:
        """Initialize Downsample module.

        Args:
@@ -68,6 +93,7 @@ def __init__(self, in_chs: int, out_chs: int, stride: int = 1, dilation: int = 1
            stride: Stride for downsampling.
            dilation: Dilation rate.
        """
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        avg_stride = stride if dilation == 1 else 1
        if stride > 1 or dilation > 1:
@@ -77,7 +103,7 @@ def __init__(self, in_chs: int, out_chs: int, stride: int = 1, dilation: int = 1
            self.pool = nn.Identity()

        if in_chs != out_chs:
-            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
+            self.conv = create_conv2d(in_chs, out_chs, 1, stride=1, **dd)
        else:
            self.conv = nn.Identity()

@@ -115,6 +141,8 @@ def __init__(
            act_layer: Union[str, Callable] = 'gelu',
            norm_layer: Optional[Callable] = None,
            drop_path: float = 0.,
+            device=None,
+            dtype=None,
    ):
        """

@@ -133,6 +161,7 @@ def __init__(
            norm_layer: Normalization layer (defaults to LN if not specified).
            drop_path: Stochastic depth probability.
        """
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        out_chs = out_chs or in_chs
        dilation = to_ntuple(2)(dilation)
@@ -149,12 +178,18 @@ def __init__(
            dilation=dilation[0],
            depthwise=True,
            bias=conv_bias,
+            **dd,
+        )
+        self.norm = norm_layer(out_chs, **dd)
+        self.mlp = mlp_layer(
+            out_chs,
+            int(mlp_ratio * out_chs),
+            act_layer=act_layer,
+            **dd,
        )
-        self.norm = norm_layer(out_chs)
-        self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
-        self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
+        self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs, **dd)) if ls_init_value is not None else None
        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
-            self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0])
+            self.shortcut = Downsample(in_chs, out_chs, stride=stride, dilation=dilation[0], **dd)
        else:
            self.shortcut = nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@@ -196,7 +231,9 @@ def __init__(
            use_grn: bool = False,
            act_layer: Union[str, Callable] = 'gelu',
            norm_layer: Optional[Callable] = None,
-            norm_layer_cl: Optional[Callable] = None
+            norm_layer_cl: Optional[Callable] = None,
+            device=None,
+            dtype=None,
    ) -> None:
        """Initialize ConvNeXt stage.

@@ -216,14 +253,15 @@ def __init__(
            norm_layer: Normalization layer.
            norm_layer_cl: Normalization layer for channels last.
        """
+        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        self.grad_checkpointing = False

        if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
            ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
            pad = 'same' if dilation[1] > 1 else 0  # same padding needed if dilation used
            self.downsample = nn.Sequential(
-                norm_layer(in_chs),
+                norm_layer(in_chs, **dd),
                create_conv2d(
                    in_chs,
                    out_chs,
@@ -232,6 +270,7 @@ def __init__(
                    dilation=dilation[0],
                    padding=pad,
                    bias=conv_bias,
+                    **dd,
                ),
            )
            in_chs = out_chs
@@ -253,6 +292,7 @@ def __init__(
                use_grn=use_grn,
                act_layer=act_layer,
                norm_layer=norm_layer if conv_mlp else norm_layer_cl,
+                **dd,
            ))
            in_chs = out_chs
        self.blocks = nn.Sequential(*stage_blocks)
@@ -324,6 +364,8 @@ def __init__(
            norm_eps: Optional[float] = None,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
+            device=None,
+            dtype=None,
    ):
        """
        Args:
@@ -349,6 +391,7 @@ def __init__(
            drop_path_rate: Stochastic depth drop rate.
        """
        super().__init__()
+        dd = {'device': device, 'dtype': dtype}
        assert output_stride in (8, 16, 32)
        kernel_sizes = to_ntuple(4)(kernel_sizes)
        norm_layer, norm_layer_cl = _get_norm_layers(norm_layer, conv_mlp, norm_eps)
@@ -362,17 +405,17 @@ def __init__(
        if stem_type == 'patch':
            # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
            self.stem = nn.Sequential(
-                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
-                norm_layer(dims[0]),
+                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias, **dd),
+                norm_layer(dims[0], **dd),
            )
            stem_stride = patch_size
        else:
            mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
            self.stem = nn.Sequential(*filter(None, [
-                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
+                nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
                act_layer() if 'act' in stem_type else None,
-                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
-                norm_layer(dims[0]),
+                nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias, **dd),
+                norm_layer(dims[0], **dd),
            ]))
            stem_stride = 4

@@ -406,6 +449,7 @@ def __init__(
                act_layer=act_layer,
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl,
+                **dd,
            ))
            prev_chs = out_chs
            # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
@@ -417,12 +461,13 @@ def __init__(
        # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
        if head_norm_first:
            assert not head_hidden_size
-            self.norm_pre = norm_layer(self.num_features)
+            self.norm_pre = norm_layer(self.num_features, **dd)
            self.head = ClassifierHead(
                self.num_features,
                num_classes,
                pool_type=global_pool,
                drop_rate=self.drop_rate,
+                **dd,
            )
        else:
            self.norm_pre = nn.Identity()
@@ -434,6 +479,7 @@ def __init__(
                drop_rate=self.drop_rate,
                norm_layer=norm_layer,
                act_layer='gelu',
+                **dd,
            )
        self.head_hidden_size = self.head.num_features
        named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
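
Taken together, these changes thread the standard PyTorch factory kwargs (`device`, `dtype`) through every ConvNeXt submodule via a shared `dd` dict, so parameters are allocated directly with the requested placement rather than created in default float32 and converted afterwards. Below is a minimal usage sketch, assuming the new kwargs are exposed end to end on the `ConvNeXt` constructor and forwarded to the norm and head layers it builds, as the diff above suggests:

import torch
from timm.models.convnext import ConvNeXt

# Sketch only: relies on the device/dtype kwargs added in this diff being accepted
# by ConvNeXt and passed down to the stem, stages, blocks, and classifier head.
model = ConvNeXt(
    depths=(3, 3, 9, 3),        # ConvNeXt-T layout
    dims=(96, 192, 384, 768),
    num_classes=1000,
    device='cpu',               # or 'cuda' / 'meta'
    dtype=torch.bfloat16,       # parameters materialize directly in bfloat16
)
assert all(p.dtype == torch.bfloat16 for p in model.parameters())

The same plumbing is also what a `device='meta'` construction path would rely on for deferred initialization, since no submodule falls back to allocating real float32 storage.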