
Commit f674826

Merge branch 'main' into vcp-tutorials
2 parents e6e3048 + 89b2917 commit f674826

File tree

23 files changed (+2117, -2281 lines)

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
"""Use pre-trained ImageNet models to extract features from images."""

# %%
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import timm
import torch
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm

from viscy.data.triplet import TripletDataModule
from viscy.transforms import ScaleIntensityRangePercentilesd

# %%
# Load an ImageNet-pretrained ConvNeXt-Tiny backbone in eval mode on the GPU.
model = timm.create_model("convnext_tiny", pretrained=True).eval().to("cuda")

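# %%
# Optional sanity check (a sketch, not in the original script): run a dummy
# 128x128 input through the backbone to confirm the feature-map layout
# before committing to the full extraction loop.
with torch.inference_mode():
    dummy = torch.zeros(1, 3, 128, 128, device="cuda")
    print(model.forward_features(dummy).shape)  # expected: (1, 768, 4, 4)
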
# %%
dm = TripletDataModule(
    data_path="/hpc/projects/organelle_phenotyping/ALFI_models_data/datasets/zarr_datasets/float_phase_ome_zarr_output_test.zarr",
    tracks_path="/hpc/projects/organelle_phenotyping/ALFI_models_data/datasets/zarr_datasets/track_phase_ome_zarr_output_test.zarr",
    source_channel=["DIC"],
    z_range=(0, 1),
    batch_size=128,
    num_workers=8,
    initial_yx_patch_size=(128, 128),
    final_yx_patch_size=(128, 128),
    normalizations=[
        ScaleIntensityRangePercentilesd(
            keys=["DIC"], lower=50, upper=99, b_min=0.0, b_max=1.0
        )
    ],
)
dm.prepare_data()
dm.setup("predict")

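# %%
# Optional sanity check (sketch; assumes the anchor tensor is laid out as
# (batch, channel, z, y, x)): inspect one batch before running the full loop.
example = next(iter(dm.predict_dataloader()))
print(example["anchor"].shape)  # e.g. (128, 1, 1, 128, 128) for this config
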
# %%
features = []
indices = []

with torch.inference_mode():
    for batch in tqdm(dm.predict_dataloader()):
        # Take the single z-slice: (B, 1, Z, Y, X) -> (B, 1, Y, X).
        image = batch["anchor"][:, :, 0]
        # Repeat the grayscale channel three times to mimic the RGB input
        # expected by the ImageNet-pretrained model.
        rgb_image = image.repeat(1, 3, 1, 1).to("cuda")
        features.append(model.forward_features(rgb_image))
        indices.append(batch["index"])

# %%
# Global average pooling over the spatial dimensions gives one vector per cell.
pooled = torch.cat(features).mean(dim=(2, 3)).cpu().numpy()
# ignore_index=True gives the concatenated table a unique, contiguous index;
# otherwise the later label assignment would align on duplicate batch indices.
tracks = pd.concat([pd.DataFrame(idx) for idx in indices], ignore_index=True)

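# %%
# Sanity check (sketch): one pooled feature vector per row of the track table.
assert pooled.shape[0] == len(tracks)
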
# %%
scaled_features = StandardScaler().fit_transform(pooled)
pca = PCA(n_components=2)
pca_features = pca.fit_transform(scaled_features)

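# Optional check (not in the original script): how much variance the first
# two principal components actually capture.
print(pca.explained_variance_ratio_)
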
# %% add the pooled features to the dataframe, naming each column feature_i
for i, feature in enumerate(pooled.T):
    tracks[f"feature_{i}"] = feature
# add the PCA features, naming each column pca_i
for i, feature in enumerate(pca_features.T):
    tracks[f"pca_{i}"] = feature

# # save the dataframe as csv
# tracks.to_csv("/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/code/ALFI/imagenet_pretrained_features.csv", index=False)

# %% load the dataframe
# tracks = pd.read_csv("/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/code/ALFI/imagenet_pretrained_features.csv")

# %% load annotations

ann_root = Path(
    "/hpc/projects/organelle_phenotyping/ALFI_models_data/datasets/zarr_datasets"
)
ann_path = ann_root / "test_annotations.csv"
annotation = pd.read_csv(ann_path)

# Add the division labels to the track table; this assumes the annotation
# rows are in the same order as the rows of `tracks`.
tracks["division"] = annotation["division"].to_numpy()

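# %%
# Optional check (sketch): the division label distribution; -1 appears to
# mark cells without an annotation and is dropped before training below.
print(tracks["division"].value_counts())
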
# %%
ax = sns.scatterplot(
    x=tracks["pca_0"],
    y=tracks["pca_1"],
    hue=tracks["division"],
    legend="full",
)
ax.set_xlabel("PCA1")
ax.set_ylabel("PCA2")

# %% compute the accuracy of the model using a linear classifier

# remove rows with division = -1 (cells without a division label)
tracks = tracks[tracks["division"] != -1]

# training set: FOVs whose names contain "/0/0/0", "/0/1/0", or "/0/2/0"
data_train_val = tracks[
    tracks["fov_name"].str.contains("/0/0/0")
    | tracks["fov_name"].str.contains("/0/1/0")
    | tracks["fov_name"].str.contains("/0/2/0")
]

# test set: FOVs whose names contain "/0/3/0" or "/0/4/0"
data_test = tracks[
    tracks["fov_name"].str.contains("/0/3/0")
    | tracks["fov_name"].str.contains("/0/4/0")
]

# keep only the feature_i columns as predictors
x_train = data_train_val.drop(
    columns=[
        "division",
        "fov_name",
        "t",
        "track_id",
        "id",
        "parent_id",
        "parent_track_id",
        "pca_0",
        "pca_1",
    ]
)
y_train = data_train_val["division"]

# train a logistic regression model
clf = LogisticRegression(random_state=0).fit(x_train, y_train)

# evaluate the trained classifier on the held-out FOVs

x_test = data_test.drop(
    columns=[
        "division",
        "fov_name",
        "t",
        "track_id",
        "id",
        "parent_id",
        "parent_track_id",
        "pca_0",
        "pca_1",
    ]
)
y_test = data_test["division"]

# predict the division state for the test set
y_pred = clf.predict(x_test)

# compute and report the accuracy of the classifier
accuracy = np.mean(y_pred == y_test)
print(f"Accuracy of model: {accuracy}")

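# %%
# Baseline for context (a sketch, not in the original script): the accuracy
# obtained by always predicting the majority class in the test set.
majority_accuracy = y_test.value_counts(normalize=True).max()
print(f"Majority-class baseline: {majority_accuracy:.3f}")
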
# %%
