
Commit 0a72176

Authored by ziw-liu, mattersoflight, alishbaimran (Alishba Imran), and duopeng

Single-cell phenotyping with contrastive learning (#113)
* first draft of contrastive learning model
* fixed stem and projection head, drafted lightning module
* Contrastive_dataloader (#99)
* initial dataloader.py
* Update dataloader_test.py
* Update dataloader_test.py
* Update dataloader_test.py
* Update dataloader_test.py
* rename training script
* move contrastive network to viscy.representation module
* Update hcs.py
* refactored class names
* correct imports
* cleaner names for model arch and module
* new imports
* Fixed epoch loss logging and WandB integration in ContrastiveModule
* updated training_script.py
* Update hcs.py
* contrastive.py
* engine.py
* script to test data I/O speed from different filesystems
* moved applications folder to viscy.applications so that `pip install -e .` works
* add resnet50 to ContrastiveEncoder
* rename training_script.py to training_script_resnet.py
* test dataloader on lustre and vast
* move training_script_resnet to viscy.applications so that `pip install -e .` works
* refined the tests for the contrastive dataloader
* sbatch script for dataloader
* delete redundant module
* nits: updated the model construction of the contrastive resnet encoder
* Updated training script, HCS data handling, engine, and contrastive representation
* Fix normalization, visualization, logging, and multi-channel prediction issues
* updated training and prediction
* update training and prediction script
* formatting
* combine the application directories
* lint
* replace notebook with script
* format script
* rename scripts conflicting with pytest
* lint application scripts
* do not filter all warnings
* log instead of print
* split data modules by task
* clean up imports
* update typing
* use pathlib
* remove redundant file
* updated predict.py
* better typing
* WIP: triplet dataset
* avoid forward ref (this might increase code analysis time a tiny bit but should not have any effect at runtime)
* check that the z range is valid and fix indexing
* clean up and explain random sampling
* sample a dict instead of a tuple and include the track index
* take out generic HCS methods for reuse
* implement TripletDataModule
* use the new batch type in engine
* better typing
* read normalization metadata
* docstring for the data module
* drop normalization metadata after transformation
* remove unused import
* fix initial crop size
* Infection state (#118)
* updated prediction code
* updated predict code
* updated code
* fixed the stem and forward pass (#115)
* update forward calls to encoder
* self.encoder -> self.model
* nits
* L2-normalize projections
* black compliance
* black compliance
* WIP: save progress before merging
* updated contrastive.py
* stem update
* updated predict code
* Delete viscy/applications/contrastive_phenotyping/PCA.ipynb
* pushing updated dataloader test
* pca deleted
* training and dataloader test
* updated structure
* deleted files
* updated training merged files
* removed commented code
* removed unneeded code
* removed unneeded code
* removed comments
* snake_case
* fixed CI issues
* removed num_fovs

Co-authored-by: Shalin Mehta <[email protected]>
Co-authored-by: Alishba Imran <[email protected]>
Co-authored-by: Duo Peng <[email protected]>
1 parent 9ff3ceb commit 0a72176

File tree

13 files changed: +2049 −63 lines
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
```python
# %% Imports and paths.
import timm
import torch
import torchview

from viscy.light.engine import ContrastiveModule
from viscy.representation.contrastive import ContrastiveEncoder, UNeXt2Stem

# %load_ext autoreload
# %autoreload 2

# %% Initialize the model and log the graph.
contra_model = ContrastiveEncoder(
    backbone="convnext_tiny"
)  # other options: convnext_tiny, resnet50
print(contra_model)
model_graph = torchview.draw_graph(
    contra_model,
    torch.randn(1, 2, 15, 224, 224),
    depth=3,  # adjust depth to zoom in.
    device="cpu",
)
# Print the image of the model.
model_graph.resize_graph(scale=2.5)
model_graph.visual_graph

# %% Initialize a resnet50 model and log the graph.
contra_model = ContrastiveEncoder(
    backbone="resnet50", in_stack_depth=16, stem_kernel_size=(4, 3, 3)
)  # note that the first ResNet layer takes 64 channels, so the depth folding factor can't be a multiple of 3
print(contra_model)
model_graph = torchview.draw_graph(
    contra_model,
    torch.randn(1, 2, 16, 224, 224),
    depth=3,  # adjust depth to zoom in.
    device="cpu",
)
# Print the image of the model.
model_graph.resize_graph(scale=2.5)
model_graph.visual_graph
```
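A quick arithmetic check of the constraint mentioned in the comment above: ResNet-50's first layer expects 64 channels, so, under our reading of the stem, the number of depth slices folded into the channel axis must divide 64, which rules out slice counts that are multiples of 3. This is a hedged sketch of the constraint, not code from the viscy implementation; `folding_ok` is an illustrative helper.

```python
# Assumption: the stem folds depth slices into channels, so the slice
# count must divide the backbone's first-layer channel width.
FIRST_LAYER_CHANNELS = 64  # ResNet-50 conv1 output width


def folding_ok(in_stack_depth: int, kernel_depth: int) -> bool:
    """Check whether the folded depth slices divide the first-layer width."""
    depth_slices = in_stack_depth // kernel_depth
    return FIRST_LAYER_CHANNELS % depth_slices == 0


print(folding_ok(16, 4))  # True: 16/4 = 4 slices, and 64 % 4 == 0
print(folding_ok(15, 5))  # False: 15/5 = 3 slices, and 64 % 3 != 0
```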
```python
# %% Initialize the lightning module and view the model.
contrastive_module = ContrastiveModule()
print(contrastive_module.encoder)

# %%
model_graph = torchview.draw_graph(
    contrastive_module.encoder,
    torch.randn(1, 2, 15, 200, 200),
    depth=3,  # adjust depth to zoom in.
    device="cpu",
)
# Print the image of the model.
model_graph.visual_graph

# %% Playground

available_models = timm.list_models(pretrained=True)

stem = UNeXt2Stem(
    in_channels=2, out_channels=96, kernel_size=(5, 2, 2), in_stack_depth=15
)
print(stem)
stem_graph = torchview.draw_graph(
    stem,
    torch.randn(1, 2, 15, 256, 256),
    depth=2,  # adjust depth to zoom in.
    device="cpu",
)
# Print the image of the stem.
stem_graph.visual_graph
```
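Working through the expected tensor shapes for the stem call above. This is a hedged reading, assuming the stem is a Conv3d with stride equal to its kernel size whose remaining depth slices are folded into channels; the per-slice channel count is an inference, but the final shape is consistent with the 96-channel, 128×128 input fed to the stem-less encoder later in this script.

```python
# Hypothetical shape walk-through for UNeXt2Stem(in_channels=2,
# out_channels=96, kernel_size=(5, 2, 2), in_stack_depth=15).
batch, in_channels, depth, height, width = 1, 2, 15, 256, 256
out_channels = 96
kd, kh, kw = 5, 2, 2  # stride assumed equal to the kernel size

depth_slices = depth // kd  # 15 / 5 = 3 slices along Z
conv_channels = out_channels // depth_slices  # 32 channels per slice (assumption)
out_shape = (batch, conv_channels * depth_slices, height // kh, width // kw)
print(out_shape)  # (1, 96, 128, 128)
```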
```python
# %%
encoder = timm.create_model(
    "convnext_tiny",
    pretrained=True,
    features_only=False,
    num_classes=200,
)

print(encoder)

# %%
encoder.stem = stem

model_graph = torchview.draw_graph(
    encoder,
    torch.randn(1, 2, 15, 256, 256),
    depth=2,  # adjust depth to zoom in.
    device="cpu",
)
# Print the image of the model.
model_graph.visual_graph

# %%
encoder.stem = torch.nn.Identity()

encoder_graph = torchview.draw_graph(
    encoder,
    torch.randn(1, 96, 128, 128),
    depth=2,  # adjust depth to zoom in.
    device="cpu",
)
# Print the image of the encoder.
encoder_graph.visual_graph

# %%
```
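To make the depth-folding idea concrete, here is a minimal standalone sketch in plain PyTorch of a strided Conv3d stem whose remaining depth slices are reshaped into channels, yielding the (1, 96, 128, 128) tensor that the stem-less encoder above consumes. The layer sizes are illustrative assumptions, not the actual UNeXt2Stem code.

```python
import torch

# Illustrative stand-in for a depth-folding stem (assumption, not the
# actual UNeXt2Stem implementation): strided Conv3d, then fold the
# remaining depth slices into the channel dimension.
conv = torch.nn.Conv3d(2, 32, kernel_size=(5, 2, 2), stride=(5, 2, 2))
x = torch.randn(1, 2, 15, 256, 256)
y = conv(x)  # (1, 32, 3, 128, 128): 3 depth slices remain
b, c, d, h, w = y.shape
tokens = y.reshape(b, c * d, h, w)  # fold depth slices into channels
print(tuple(tokens.shape))  # (1, 96, 128, 128)
```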
