Commit 9678911

ziw-liu authored and edyoshikun committed
Configurable drop path rate in contrastive models (#131)
* log instead of print
* configurable drop path rate
* fix docstring
1 parent 147a61a commit 9678911

2 files changed: +27 -16 lines


viscy/light/engine.py

Lines changed: 2 additions & 0 deletions

@@ -584,6 +584,7 @@ def __init__(
         stem_kernel_size: tuple[int, int, int] = (5, 4, 4),
         embedding_len: int = 256,
         predict: bool = False,
+        drop_path_rate: float = 0.2,
         tracks_path: str = "data/tracks",
         features_output_path: str = "",
         projections_output_path: str = "",
@@ -615,6 +616,7 @@ def __init__(
             stem_kernel_size=stem_kernel_size,
             embedding_len=embedding_len,
             predict=predict,
+            drop_path_rate=drop_path_rate,
         )
         self.example_input_array = torch.rand(
             1, in_channels, in_stack_depth, *example_input_yx_shape
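The change in engine.py only exposes the new keyword on the LightningModule and forwards it to the encoder it constructs. As a rough sketch of what that enables, the snippet below builds the module with a non-default drop path rate; "ContrastiveModule" is an assumed name for the LightningModule whose __init__ is patched above (the class name and its full signature are not visible in this diff), and only keywords shown in the diff are used, relying on defaults for everything else.

# Hypothetical sketch -- "ContrastiveModule" is an assumed class name; the
# diff above only shows part of its __init__ signature.
from viscy.light.engine import ContrastiveModule

model = ContrastiveModule(
    stem_kernel_size=(5, 4, 4),
    embedding_len=256,
    predict=False,
    drop_path_rate=0.1,  # previously fixed at 0.2 inside the encoder
)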

viscy/representation/contrastive.py

Lines changed: 25 additions & 16 deletions

@@ -1,9 +1,13 @@
+import logging
+
 import timm
 import torch.nn as nn
 import torch.nn.functional as F
 
 from viscy.unet.networks.unext2 import StemDepthtoChannels
 
+_logger = logging.getLogger("lightning.pytorch")
+
 
 class ContrastiveEncoder(nn.Module):
     def __init__(
@@ -15,33 +19,38 @@ def __init__(
         embedding_len: int = 256,
         stem_stride: int = 2,
         predict: bool = False,
+        drop_path_rate: float = 0.2,
     ):
+        """ContrastiveEncoder network that uses
+        ConvNext and ResNet backbons from timm.
+
+        :param str backbone: Backbone architecture for the encoder,
+            defaults to "convnext_tiny"
+        :param int in_channels: Number of input channels, defaults to 2
+        :param int in_stack_depth: Number of input slices in z-stack, defaults to 12
+        :param tuple[int, int, int] stem_kernel_size: 3D kernel size for the stem.
+            Input stack depth must be divisible by the kernel depth,
+            defaults to (5, 3, 3)
+        :param int embedding_len: Length of the embedding vector, defaults to 256
+        :param int stem_stride: stride of the stem, defaults to 2
+        :param bool predict: prediction mode, defaults to False
+        :param float drop_path_rate: probability that residual connections
+            are dropped during training, defaults to 0.2
+        """
         super().__init__()
-
         self.predict = predict
         self.backbone = backbone
 
-        """
-        ContrastiveEncoder network that uses ConvNext and ResNet backbons from timm.
-
-        Parameters:
-        - backbone (str): Backbone architecture for the encoder. Default is "convnext_tiny".
-        - in_channels (int): Number of input channels. Default is 2.
-        - in_stack_depth (int): Number of input slices in z-stack. Default is 15.
-        - stem_kernel_size (tuple[int, int, int]): 3D kernel size for the stem. Input stack depth must be divisible by the kernel depth. Default is (5, 3, 3).
-        - embedding_len (int): Length of the embedding. Default is 1000.
-        """
-
         encoder = timm.create_model(
             backbone,
             pretrained=True,
             features_only=False,
-            drop_path_rate=0.2,
+            drop_path_rate=drop_path_rate,
             num_classes=3 * embedding_len,
         )
 
         if "convnext" in backbone:
-            print("Using ConvNext backbone.")
+            _logger.debug(f"Using ConvNeXt backbone for {type(self).__name__}.")
 
             in_channels_encoder = encoder.stem[0].out_channels
 
@@ -58,7 +67,7 @@ def __init__(
             encoder.head.fc = nn.Identity()
 
         elif "resnet" in backbone:
-            print("Using ResNet backbone.")
+            _logger.debug(f"Using ResNet backbone for {type(self).__name__}")
             # Adapt stem and projection head of resnet here.
             # replace the stem designed for RGB images with a stem designed to handle 3D multi-channel input.
 
@@ -73,7 +82,7 @@ def __init__(
             encoder.fc = nn.Identity()
 
         # Create a new stem that can handle 3D multi-channel input.
-        print("using stem kernel size", stem_kernel_size)
+        _logger.debug(f"Stem kernel size: {stem_kernel_size}")
         self.stem = StemDepthtoChannels(
             in_channels, in_stack_depth, in_channels_encoder, stem_kernel_size
        )
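On the encoder side, the rate that was hard-coded in the timm.create_model call is now the drop_path_rate constructor argument, and the backbone/stem messages are emitted on the "lightning.pytorch" logger at DEBUG level instead of being printed. A minimal usage sketch, with illustrative parameter values (only the names documented in the docstring above are used):

import logging

from viscy.representation.contrastive import ContrastiveEncoder

# Surface the new debug messages (backbone choice, stem kernel size) that
# replaced the old print() calls; a basic root handler is enough if
# Lightning's own logging setup is not active.
logging.basicConfig(level=logging.INFO)
logging.getLogger("lightning.pytorch").setLevel(logging.DEBUG)

# drop_path_rate is forwarded to timm.create_model: 0.0 disables stochastic
# depth, while the default of 0.2 matches the previously hard-coded value.
encoder = ContrastiveEncoder(
    backbone="convnext_tiny",
    in_channels=2,
    in_stack_depth=15,            # must be divisible by the stem kernel depth
    stem_kernel_size=(5, 3, 3),
    drop_path_rate=0.1,
)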
