diff --git a/viscy/representation/contrastive.py b/viscy/representation/contrastive.py index 8edeb8623..1b08e089a 100644 --- a/viscy/representation/contrastive.py +++ b/viscy/representation/contrastive.py @@ -2,11 +2,22 @@ import timm import torch.nn as nn +from monai.networks.nets.resnet import ResNetFeatures from torch import Tensor from viscy.unet.networks.unext2 import StemDepthtoChannels +def projection_mlp(in_dims: int, hidden_dims: int, out_dims: int) -> nn.Module: + return nn.Sequential( + nn.Linear(in_dims, hidden_dims), + nn.BatchNorm1d(hidden_dims), + nn.ReLU(inplace=True), + nn.Linear(hidden_dims, out_dims), + nn.BatchNorm1d(out_dims), + ) + + class ContrastiveEncoder(nn.Module): """ Contrastive encoder network that uses ConvNeXt v1 and ResNet backbones from timm. @@ -63,12 +74,8 @@ def __init__( in_channels_encoder = encoder.conv1.out_channels encoder.conv1 = nn.Identity() # Save projection head separately and erase the projection head contained within the encoder. - projection = nn.Sequential( - nn.Linear(encoder.head.fc.in_features, embedding_dim), - nn.BatchNorm1d(embedding_dim), - nn.ReLU(inplace=True), - nn.Linear(embedding_dim, projection_dim), - nn.BatchNorm1d(projection_dim), + projection = projection_mlp( + encoder.head.fc.in_features, embedding_dim, projection_dim ) encoder.head.fc = nn.Identity() # Create a new stem that can handle 3D multi-channel input. @@ -102,3 +109,54 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: embedding = self.encoder(x) projections = self.projection(embedding) return (embedding, projections) + + +class ResNet3dEncoder(nn.Module): + """ + 3D ResNet encoder network that uses MONAI's ResNetFeatures. + + Parameters + ---------- + backbone : str + Name of the backbone model + in_channels : int, optional + Number of input channels + embedding_dim : int, optional + Embedded feature dimension that matches backbone output channels, + by default 512 (ResNet-18) + projection_dim : int, optional + Projection dimension for computing loss, by default 128 + """ + + def __init__( + self, + backbone: str, + in_channels: int = 1, + embedding_dim: int = 512, + projection_dim: int = 128, + ) -> None: + super().__init__() + self.encoder = ResNetFeatures( + backbone, pretrained=True, spatial_dims=3, in_channels=in_channels + ) + self.projection = projection_mlp(embedding_dim, embedding_dim, projection_dim) + + def forward(self, x: Tensor) -> tuple[Tensor, Tensor]: + """ + Forward pass. + + Parameters + ---------- + x : Tensor + Input image + + Returns + ------- + tuple[Tensor, Tensor] + The embedding tensor and the projection tensor + """ + feature_map = self.encoder(x)[-1] + embedding = self.encoder.avgpool(feature_map) + embedding = embedding.view(embedding.size(0), -1) + projections = self.projection(embedding) + return (embedding, projections) diff --git a/viscy/scripts/count_flops.py b/viscy/scripts/count_flops.py index ec4610899..fa9ee7747 100644 --- a/viscy/scripts/count_flops.py +++ b/viscy/scripts/count_flops.py @@ -2,27 +2,45 @@ import torch from ptflops import get_model_complexity_info -from viscy.translation.engine import VSUNet +from viscy.representation.contrastive import ContrastiveEncoder, ResNet3dEncoder + + +# %% +def print_flops(model): + with torch.cuda.device(0): + macs, params = get_model_complexity_info( + model, + (1, 32, 128, 128), # print_per_layer_stat=False + ) + print(macs, params) + # %% -model = VSUNet( - architecture="UNeXt2", - model_config={ - "in_channels": 1, - "out_channels": 2, - "in_stack_depth": 5, - "backbone": "convnextv2_tiny", - "stem_kernel_size": (5, 4, 4), - "decoder_mode": "pixelshuffle", - "head_expansion_ratio": 4, - }, +resnet_3d = ResNet3dEncoder( + "resnet10", in_channels=1, embedding_dim=512, projection_dim=32 ) +print_flops(resnet_3d) + # %% -with torch.cuda.device(0): - macs, params = get_model_complexity_info( - model, - (1, 5, 2048, 2048), # print_per_layer_stat=False - ) -print(macs, params) +with torch.inference_mode(): + features = resnet_3d(torch.rand(1, 1, 32, 128, 128)) + +for f in features: + print(f.shape) + +# %% +convnext_2d = ContrastiveEncoder( + backbone="convnext_tiny", + in_channels=1, + in_stack_depth=32, + stem_kernel_size=(4, 4, 4), + stem_stride=(4, 4, 4), + embedding_dim=768, + projection_dim=32, + drop_path_rate=0.0, +) + +print_flops(convnext_2d) + # %%