
Commit 308392c

ziw-liu authored and mattersoflight committed

Config-based prediction with Xarray-based output format (#132)

* use callback to write prediction embeddings
* move over the script to compute infection score from contrastive_update
* delete unused stem module
* organize scripts and CLIs for contrastive phenotyping
* add dependencies for prediction
* export embedding dataset reader function
* add more plots to script
* use real paths in predict config
* do not require seaborn and umap-learn for base install
* use relative path in example job script
* add docstrings for embedding writer and reader
* don't assign unused grid object
* show time and id as hover data in interactive plot
* fix typo
* fix script to test data i/o
* ignore accidental lightning_logs
* add plotly and nbformat to visual dependencies
* tweak predict cli example
* add another plot type - raw features of random samples
* comment on speed of clustermap
* add prediction config example to specify log path
* simplify env var in job script and match cpu count with config
* vectorize string concatenation

Co-authored-by: Shalin Mehta <[email protected]>

1 parent fb2ec0f commit 308392c

19 files changed: +590 −227 lines
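For orientation, here is a minimal sketch of reading back the new Xarray-based prediction output with the reader exported in this commit. The usage follows the visualization script later in this diff; the zarr path is a placeholder, not a real output from this repository.

from viscy.light.embedding_writer import read_embedding_dataset

# Placeholder path: point this at the output_path configured for the
# EmbeddingWriter callback in the predict config below.
dataset = read_embedding_dataset("predictions.zarr")
features = dataset["features"]  # per-sample embeddings with fov_name, track_id, and t coordinates
print(features.sizes)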

.gitignore

Lines changed: 4 additions & 1 deletion

@@ -40,4 +40,7 @@ htmlcov/
 coverage.xml
 *.cover
 .hypothesis/
-.pytest_cache/
+.pytest_cache/
+
+#lightning_logs directory
+lightning_logs/
File renamed without changes.
File renamed without changes.
Lines changed: 145 additions & 0 deletions

# %%
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from umap import UMAP

from viscy.light.embedding_writer import read_embedding_dataset

# %%
dataset = read_embedding_dataset(
    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/2024_02_04-tokenized-drop_path_0_0.zarr"
)
dataset

# %%
# load all unprojected features:
features = dataset["features"]
# or select a well:
# features = features[features["fov_name"].str.contains("B/4")]
features

# %%
# examine raw features
random_samples = np.random.randint(0, dataset.sizes["sample"], 700)
# concatenate fov_name, track_id, and t to create a unique sample identifier
sample_id = (
    features["fov_name"][random_samples]
    + "-"
    + features["track_id"][random_samples].astype(str)
    + "-"
    + features["t"][random_samples].astype(str)
)
px.imshow(
    features.values[random_samples],
    labels={
        "x": "feature",
        "y": "sample",
        "color": "value",
    },  # change labels to match our metadata
    y=sample_id,
    # show fov_name as y-axis
)

# %%
scaled_features = StandardScaler().fit_transform(features.values)

umap = UMAP()

embedding = umap.fit_transform(scaled_features)
features = (
    features.assign_coords(UMAP1=("sample", embedding[:, 0]))
    .assign_coords(UMAP2=("sample", embedding[:, 1]))
    .set_index(sample=["UMAP1", "UMAP2"], append=True)
)
features

# %%
sns.scatterplot(
    x=features["UMAP1"], y=features["UMAP2"], hue=features["t"], s=7, alpha=0.8
)


# %%
def load_annotation(da, path, name, categories: dict | None = None):
    annotation = pd.read_csv(path)
    annotation["fov_name"] = "/" + annotation["fov ID"]
    annotation = annotation.set_index(["fov_name", "id"])
    mi = pd.MultiIndex.from_arrays(
        [da["fov_name"].values, da["id"].values], names=["fov_name", "id"]
    )
    selected = annotation.loc[mi][name]
    if categories:
        selected = selected.astype("category").cat.rename_categories(categories)
    return selected


# %%
ann_root = Path(
    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track"
)

infection = load_annotation(
    features,
    ann_root / "tracking_v1_infection.csv",
    "infection class",
    {0.0: "background", 1.0: "uninfected", 2.0: "infected"},
)
division = load_annotation(
    features,
    ann_root / "cell_division_state.csv",
    "division",
    {0: "non-dividing", 2: "dividing"},
)


# %%
sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=division, s=7, alpha=0.8)

# %%
sns.scatterplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, s=7, alpha=0.8)

# %%
ax = sns.histplot(x=features["UMAP1"], y=features["UMAP2"], hue=infection, bins=64)
sns.move_legend(ax, loc="lower left")

# %%
sns.displot(
    x=features["UMAP1"],
    y=features["UMAP2"],
    kind="hist",
    col=infection,
    bins=64,
    cmap="inferno",
)

# %%
# interactive scatter plot to associate clusters with specific cells

px.scatter(
    data_frame=pd.DataFrame(
        {k: v for k, v in features.coords.items() if k != "features"}
    ),
    x="UMAP1",
    y="UMAP2",
    color=(infection.astype(str) + " " + division.astype(str)).rename("annotation"),
    hover_name="fov_name",
    hover_data=["id", "t"],
)

# %%
# cluster features in heatmap directly
# this is very slow for large datasets even with fastcluster installed
inf_codes = pd.Series(infection.values.codes, name="infection")
lut = dict(zip(inf_codes.unique(), "brw"))
row_colors = inf_codes.map(lut)

g = sns.clustermap(
    scaled_features, row_colors=row_colors.to_numpy(), col_cluster=False, cbar_pos=None
)
# hide the per-sample tick labels on the heatmap axis
g.ax_heatmap.yaxis.set_ticks([])
# %%
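A possible follow-up, sketched here as an assumption rather than part of the commit: flatten the UMAP coordinates and annotations computed above into a table for export. The filename is a placeholder.

# Hedged sketch: export per-sample UMAP coordinates and annotations to CSV.
export = pd.DataFrame(
    {
        "fov_name": features["fov_name"].values,
        "track_id": features["track_id"].values,
        "t": features["t"].values,
        "UMAP1": features["UMAP1"].values,
        "UMAP2": features["UMAP2"].values,
        "infection": infection.values,
        "division": division.values,
    }
)
export.to_csv("umap_annotations.csv", index=False)  # placeholder filename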
Lines changed: 50 additions & 0 deletions

seed_everything: 42
trainer:
  accelerator: gpu
  strategy: auto
  devices: auto
  num_nodes: 1
  precision: 32-true
  callbacks:
    - class_path: viscy.light.embedding_writer.EmbeddingWriter
      init_args:
        output_path: "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/predict/test_prediction_code.zarr"
  # edit the following lines to specify logging path
  # - class_path: lightning.pytorch.loggers.TensorBoardLogger
  #   init_args:
  #     save_dir: /path/to/save_dir
  #     version: name-of-experiment
  #     log_graph: True
  inference_mode: true
model:
  backbone: convnext_tiny
  in_channels: 2
  in_stack_depth: 15
  stem_kernel_size: [5, 4, 4]
data:
  data_path: /hpc/projects/virtual_staining/2024_02_04_A549_DENV_ZIKV_timelapse/registered_chunked.zarr
  tracks_path: /hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr
  source_channel:
    - Phase3D
    - RFP
  z_range: [28, 43]
  batch_size: 32
  num_workers: 16
  initial_yx_patch_size: [192, 192]
  final_yx_patch_size: [192, 192]
  normalizations:
    - class_path: viscy.transforms.NormalizeSampled
      init_args:
        keys: [Phase3D]
        level: fov_statistics
        subtrahend: mean
        divisor: std
    - class_path: viscy.transforms.ScaleIntensityRangePercentilesd
      init_args:
        keys: [RFP]
        lower: 50
        upper: 99
        b_min: 0.0
        b_max: 1.0
return_predictions: false
ckpt_path: /hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/contrastive_tune_augmentations/lightning_logs/tokenized-drop-path-0.0/checkpoints/epoch=96-step=23377.ckpt
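For readers who prefer the Python API over the Lightning CLI, the callback configured above could be wired up programmatically along these lines. This is a sketch, assuming the classes resolve as named in the YAML; the output path is a placeholder, and the model and datamodule must be constructed separately.

from lightning.pytorch import Trainer

from viscy.light.embedding_writer import EmbeddingWriter

trainer = Trainer(
    accelerator="gpu",
    inference_mode=True,
    # EmbeddingWriter persists predicted embeddings to the given zarr store.
    callbacks=[EmbeddingWriter(output_path="predictions.zarr")],  # placeholder path
)
# With a model and datamodule built to match the config above:
# trainer.predict(model, datamodule=data_module, return_predictions=False)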
Lines changed: 21 additions & 0 deletions

#!/bin/bash

#SBATCH --job-name=contrastive_predict
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:1
#SBATCH --partition=gpu
#SBATCH --cpus-per-task=16
#SBATCH --mem-per-cpu=7G
#SBATCH --time=0-01:00:00

module load anaconda/2022.05
# Update to use the actual prefix
conda activate $MYDATA/envs/viscy

scontrol show job $SLURM_JOB_ID

# use absolute path in production
config=./predict.yml
cat $config
srun python -m viscy.cli.contrastive_triplet predict -c $config
File renamed without changes.
File renamed without changes.
File renamed without changes.
Lines changed: 166 additions & 0 deletions

from argparse import ArgumentParser
from pathlib import Path
import numpy as np
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from viscy.data.triplet import TripletDataModule, TripletDataset
import pandas as pd
import warnings

warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    message="To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).",
)

# %% Paths and constants
save_dir = (
    "/hpc/mydata/alishba.imran/VisCy/applications/contrastive_phenotyping/embeddings4"
)

# rechunked data
data_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/2.2-register_annotations/updated_all_annotations.zarr"

# updated tracking data
tracks_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr"

source_channel = ["background_mask", "uninfected_mask", "infected_mask"]
z_range = (0, 1)
batch_size = 1  # match the number of FOVs being processed so that no data is left over
# set to 15 for full, 12 for infected, and 8 for uninfected

# non-rechunked data
data_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr"

# updated tracking data
tracks_path_1 = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_02_04_A549_DENV_ZIKV_timelapse/7.1-seg_track/tracking_v1.zarr"

source_channel_1 = ["Nuclei_prediction_labels"]


# %% Define the main function for prediction
def main(hparams):
    # Initialize the data module for prediction, re-do embeddings but with size 224 by 224
    data_module = TripletDataModule(
        data_path=data_path,
        tracks_path=tracks_path,
        source_channel=source_channel,
        z_range=z_range,
        initial_yx_patch_size=(224, 224),
        final_yx_patch_size=(224, 224),
        batch_size=batch_size,
        num_workers=hparams.num_workers,
    )

    data_module.setup(stage="predict")

    print(f"Total prediction dataset size: {len(data_module.predict_dataset)}")

    dataloader = DataLoader(
        data_module.predict_dataset,
        batch_size=batch_size,
        num_workers=hparams.num_workers,
    )

    # Initialize the second data module for segmentation masks
    seg_data_module = TripletDataModule(
        data_path=data_path_1,
        tracks_path=tracks_path_1,
        source_channel=source_channel_1,
        z_range=z_range,
        initial_yx_patch_size=(224, 224),
        final_yx_patch_size=(224, 224),
        batch_size=batch_size,
        num_workers=hparams.num_workers,
    )

    seg_data_module.setup(stage="predict")

    seg_dataloader = DataLoader(
        seg_data_module.predict_dataset,
        batch_size=batch_size,
        num_workers=hparams.num_workers,
    )

    # Initialize lists to store average values
    background_avg = []
    uninfected_avg = []
    infected_avg = []

    for batch, seg_batch in tqdm(
        zip(dataloader, seg_dataloader),
        desc="Processing batches",
        total=len(data_module.predict_dataset),
    ):
        anchor = batch["anchor"]
        seg_anchor = seg_batch["anchor"].int()

        # Extract the fov_name and id from the batch
        fov_name = batch["index"]["fov_name"][0]
        cell_id = batch["index"]["id"].item()

        fov_dirs = fov_name.split("/")
        # Construct the path to the CSV file
        csv_path = os.path.join(
            tracks_path, *fov_dirs, f"tracks{fov_name.replace('/', '_')}.csv"
        )

        # Read the CSV file
        df = pd.read_csv(csv_path)

        # Find the row with the specified id and extract the track_id
        track_id = df.loc[df["id"] == cell_id, "track_id"].values[0]

        # Create a boolean mask where segmentation values are equal to the track_id
        mask = seg_anchor == track_id
        # mask = (seg_anchor > 0)

        # Find the most frequent non-zero value in seg_anchor
        # unique, counts = np.unique(seg_anchor[seg_anchor > 0], return_counts=True)
        # most_frequent_value = unique[np.argmax(counts)]

        # # Create a boolean mask where segmentation values are equal to the most frequent value
        # mask = (seg_anchor == most_frequent_value)

        # Expand the mask to match the anchor tensor shape
        mask = mask.expand(1, 3, 1, 224, 224)

        # Calculate average values for each channel (background, uninfected, infected) using the mask
        background_avg.append(anchor[:, 0, :, :, :][mask[:, 0]].mean().item())
        uninfected_avg.append(anchor[:, 1, :, :, :][mask[:, 1]].mean().item())
        infected_avg.append(anchor[:, 2, :, :, :][mask[:, 2]].mean().item())

    # Convert lists to numpy arrays
    background_avg = np.array(background_avg)
    uninfected_avg = np.array(uninfected_avg)
    infected_avg = np.array(infected_avg)

    print("Average values per cell for each mask calculated.")
    print("Background average shape:", background_avg.shape)
    print("Uninfected average shape:", uninfected_avg.shape)
    print("Infected average shape:", infected_avg.shape)

    # Save the averages as .npy files
    np.save(os.path.join(save_dir, "background_avg.npy"), background_avg)
    np.save(os.path.join(save_dir, "uninfected_avg.npy"), uninfected_avg)
    np.save(os.path.join(save_dir, "infected_avg.npy"), infected_avg)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--backbone", type=str, default="resnet50")
    parser.add_argument("--margin", type=float, default=0.5)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--schedule", type=str, default="Constant")
    parser.add_argument("--log_steps_per_epoch", type=int, default=10)
    parser.add_argument("--embedding_len", type=int, default=256)
    parser.add_argument("--max_epochs", type=int, default=100)
    parser.add_argument("--accelerator", type=str, default="gpu")
    parser.add_argument("--devices", type=int, default=1)
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--log_every_n_steps", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=8)
    args = parser.parse_args()
    main(args)
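One hypothetical consumer of the saved arrays, sketched under the assumption that a higher mean response in the infected mask than in the uninfected mask indicates infection. The thresholding rule and file locations are illustrative, not the project's scoring method.

import numpy as np

# Assumes the .npy files saved by the script above are in the working directory.
background = np.load("background_avg.npy")
uninfected = np.load("uninfected_avg.npy")
infected = np.load("infected_avg.npy")

# Illustrative per-cell call: compare mean mask responses per cell.
calls = np.where(infected > uninfected, "infected", "uninfected")
values, counts = np.unique(calls, return_counts=True)
print(dict(zip(values, counts)))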
