Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 64 additions & 6 deletions viscy/representation/contrastive.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,22 @@

import timm
import torch.nn as nn
from monai.networks.nets.resnet import ResNetFeatures
from torch import Tensor

from viscy.unet.networks.unext2 import StemDepthtoChannels


def projection_mlp(in_dims: int, hidden_dims: int, out_dims: int) -> nn.Module:
    """
    Build a two-layer MLP projection head with batch normalization.

    Parameters
    ----------
    in_dims : int
        Number of input features
    hidden_dims : int
        Number of features in the hidden layer
    out_dims : int
        Number of output (projection) features

    Returns
    -------
    nn.Module
        ``Linear -> BatchNorm1d -> ReLU -> Linear -> BatchNorm1d`` stack
    """
    return nn.Sequential(
        nn.Linear(in_dims, hidden_dims),
        nn.BatchNorm1d(hidden_dims),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dims, out_dims),
        nn.BatchNorm1d(out_dims),
    )


class ContrastiveEncoder(nn.Module):
"""
Contrastive encoder network that uses ConvNeXt v1 and ResNet backbones from timm.
Expand Down Expand Up @@ -63,12 +74,8 @@ def __init__(
in_channels_encoder = encoder.conv1.out_channels
encoder.conv1 = nn.Identity()
# Save projection head separately and erase the projection head contained within the encoder.
projection = nn.Sequential(
nn.Linear(encoder.head.fc.in_features, embedding_dim),
nn.BatchNorm1d(embedding_dim),
nn.ReLU(inplace=True),
nn.Linear(embedding_dim, projection_dim),
nn.BatchNorm1d(projection_dim),
projection = projection_mlp(
encoder.head.fc.in_features, embedding_dim, projection_dim
)
encoder.head.fc = nn.Identity()
# Create a new stem that can handle 3D multi-channel input.
Expand Down Expand Up @@ -102,3 +109,54 @@ def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
embedding = self.encoder(x)
projections = self.projection(embedding)
return (embedding, projections)


class ResNet3dEncoder(nn.Module):
    """
    3D ResNet encoder network that uses MONAI's ResNetFeatures.

    Parameters
    ----------
    backbone : str
        Name of the backbone model
    in_channels : int, optional
        Number of input channels, by default 1
    embedding_dim : int, optional
        Embedded feature dimension that matches backbone output channels,
        by default 512 (ResNet-18)
    projection_dim : int, optional
        Projection dimension for computing loss, by default 128
    pretrained : bool, optional
        Whether to request pretrained backbone weights from MONAI,
        by default True.
        NOTE(review): MONAI appears to offer pretrained weights only for
        3D single-channel inputs, so ``False`` is presumably required when
        ``in_channels != 1`` -- confirm against the installed MONAI version.
    """

    def __init__(
        self,
        backbone: str,
        in_channels: int = 1,
        embedding_dim: int = 512,
        projection_dim: int = 128,
        pretrained: bool = True,
    ) -> None:
        super().__init__()
        self.encoder = ResNetFeatures(
            backbone,
            pretrained=pretrained,
            spatial_dims=3,
            in_channels=in_channels,
        )
        # The projection head consumes the pooled feature vector, so its input
        # and hidden widths are both ``embedding_dim``.
        self.projection = projection_mlp(embedding_dim, embedding_dim, projection_dim)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """
        Forward pass.

        Parameters
        ----------
        x : Tensor
            Input image

        Returns
        -------
        tuple[Tensor, Tensor]
            The embedding tensor and the projection tensor
        """
        # The encoder returns a sequence of feature maps; only the last is used.
        feature_map = self.encoder(x)[-1]
        embedding = self.encoder.avgpool(feature_map)
        # Collapse everything after the batch dimension into one feature axis.
        embedding = embedding.flatten(start_dim=1)
        projections = self.projection(embedding)
        return (embedding, projections)
54 changes: 36 additions & 18 deletions viscy/scripts/count_flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,45 @@
import torch
from ptflops import get_model_complexity_info

from viscy.translation.engine import VSUNet
from viscy.representation.contrastive import ContrastiveEncoder, ResNet3dEncoder


# %%
def print_flops(model, input_shape=(1, 32, 128, 128)) -> None:
    """
    Print the MACs and parameter count that ptflops reports for a model.

    Parameters
    ----------
    model : torch.nn.Module
        Model to profile
    input_shape : tuple of int, optional
        Input resolution forwarded to ``get_model_complexity_info``,
        by default (1, 32, 128, 128)
    """
    from contextlib import nullcontext

    # Pin CUDA device 0 only when CUDA is usable, so the script still runs
    # on CPU-only machines.
    if torch.cuda.is_available():
        device_context = torch.cuda.device(0)
    else:
        device_context = nullcontext()
    with device_context:
        macs, params = get_model_complexity_info(
            model,
            input_shape,  # print_per_layer_stat=False
        )
    print(macs, params)


# %%
model = VSUNet(
architecture="UNeXt2",
model_config={
"in_channels": 1,
"out_channels": 2,
"in_stack_depth": 5,
"backbone": "convnextv2_tiny",
"stem_kernel_size": (5, 4, 4),
"decoder_mode": "pixelshuffle",
"head_expansion_ratio": 4,
},
# Single-channel 3D ResNet-10 encoder with a 32-dimensional projection head.
resnet_3d = ResNet3dEncoder(
    backbone="resnet10",
    in_channels=1,
    embedding_dim=512,
    projection_dim=32,
)

print_flops(resnet_3d)

# %%
with torch.cuda.device(0):
macs, params = get_model_complexity_info(
model,
(1, 5, 2048, 2048), # print_per_layer_stat=False
)
print(macs, params)
# The projection head contains BatchNorm1d, which rejects a batch of one in
# training mode; set eval mode explicitly rather than relying on an earlier
# cell having done so.
resnet_3d.eval()
with torch.inference_mode():
    features = resnet_3d(torch.rand(1, 1, 32, 128, 128))

# The encoder returns (embedding, projections); print the shape of each.
for f in features:
    print(f.shape)

# %%
# ConvNeXt-tiny contrastive encoder for single-channel input.
convnext_2d = ContrastiveEncoder(
    backbone="convnext_tiny",
    drop_path_rate=0.0,
    # Input and stem geometry.
    in_channels=1,
    in_stack_depth=32,
    stem_kernel_size=(4, 4, 4),
    stem_stride=(4, 4, 4),
    # Output widths of the embedding and the projection head.
    embedding_dim=768,
    projection_dim=32,
)

print_flops(convnext_2d)

# %%