1
+ import logging
2
+
1
3
import timm
2
4
import torch .nn as nn
3
5
import torch .nn .functional as F
4
6
5
7
from viscy .unet .networks .unext2 import StemDepthtoChannels
6
8
9
+ _logger = logging .getLogger ("lightning.pytorch" )
10
+
7
11
8
12
class ContrastiveEncoder (nn .Module ):
9
13
def __init__ (
@@ -15,33 +19,38 @@ def __init__(
15
19
embedding_len : int = 256 ,
16
20
stem_stride : int = 2 ,
17
21
predict : bool = False ,
22
+ drop_path_rate : float = 0.2 ,
18
23
):
24
+ """ContrastiveEncoder network that uses
25
+ ConvNext and ResNet backbones from timm.
26
+
27
+ :param str backbone: Backbone architecture for the encoder,
28
+ defaults to "convnext_tiny"
29
+ :param int in_channels: Number of input channels, defaults to 2
30
+ :param int in_stack_depth: Number of input slices in z-stack, defaults to 12
31
+ :param tuple[int, int, int] stem_kernel_size: 3D kernel size for the stem.
32
+ Input stack depth must be divisible by the kernel depth,
33
+ defaults to (5, 3, 3)
34
+ :param int embedding_len: Length of the embedding vector, defaults to 256
35
+ :param int stem_stride: stride of the stem, defaults to 2
36
+ :param bool predict: prediction mode, defaults to False
37
+ :param float drop_path_rate: probability that residual connections
38
+ are dropped during training, defaults to 0.2
39
+ """
19
40
super ().__init__ ()
20
-
21
41
self .predict = predict
22
42
self .backbone = backbone
23
43
24
- """
25
- ContrastiveEncoder network that uses ConvNext and ResNet backbons from timm.
26
-
27
- Parameters:
28
- - backbone (str): Backbone architecture for the encoder. Default is "convnext_tiny".
29
- - in_channels (int): Number of input channels. Default is 2.
30
- - in_stack_depth (int): Number of input slices in z-stack. Default is 15.
31
- - stem_kernel_size (tuple[int, int, int]): 3D kernel size for the stem. Input stack depth must be divisible by the kernel depth. Default is (5, 3, 3).
32
- - embedding_len (int): Length of the embedding. Default is 1000.
33
- """
34
-
35
44
encoder = timm .create_model (
36
45
backbone ,
37
46
pretrained = True ,
38
47
features_only = False ,
39
- drop_path_rate = 0.2 ,
48
+ drop_path_rate = drop_path_rate ,
40
49
num_classes = 3 * embedding_len ,
41
50
)
42
51
43
52
if "convnext" in backbone :
44
- print ( "Using ConvNext backbone." )
53
+ _logger . debug ( f "Using ConvNeXt backbone for { type ( self ). __name__ } ." )
45
54
46
55
in_channels_encoder = encoder .stem [0 ].out_channels
47
56
@@ -58,7 +67,7 @@ def __init__(
58
67
encoder .head .fc = nn .Identity ()
59
68
60
69
elif "resnet" in backbone :
61
- print ( "Using ResNet backbone. " )
70
+ _logger . debug ( f "Using ResNet backbone for { type ( self ). __name__ } " )
62
71
# Adapt stem and projection head of resnet here.
63
72
# replace the stem designed for RGB images with a stem designed to handle 3D multi-channel input.
64
73
@@ -73,7 +82,7 @@ def __init__(
73
82
encoder .fc = nn .Identity ()
74
83
75
84
# Create a new stem that can handle 3D multi-channel input.
76
- print ( "using stem kernel size" , stem_kernel_size )
85
+ _logger . debug ( f"Stem kernel size: { stem_kernel_size } " )
77
86
self .stem = StemDepthtoChannels (
78
87
in_channels , in_stack_depth , in_channels_encoder , stem_kernel_size
79
88
)
0 commit comments