diff --git a/examples/compbio/README.md b/examples/compbio/README.md new file mode 100644 index 000000000..2ec2f662d --- /dev/null +++ b/examples/compbio/README.md @@ -0,0 +1,2 @@ +# Demo scripts for hackathon + diff --git a/examples/compbio/ddp.py b/examples/compbio/ddp.py new file mode 100644 index 000000000..b8669de29 --- /dev/null +++ b/examples/compbio/ddp.py @@ -0,0 +1,80 @@ +""" +# Distributed training + +Demonstrate how to train a model using distributed data parallel (DDP) with PyTorch Lightning. +""" + +import os +from pathlib import Path + +import torch +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger + +from viscy.data.hcs import HCSDataModule +from viscy.scripts.infection_phenotyping.classify_infection_2D import SemanticSegUNet2D +from viscy.transforms import NormalizeSampled, RandWeightedCropd + + +def main(): + dm = HCSDataModule( + data_path="/hpc/mydata/ziwen.liu/demo/Exp_2024_02_13_DENV_3infMarked_trainVal.zarr", + source_channel=["Sensor", "Phase"], + target_channel=["Inf_mask"], + yx_patch_size=(128, 128), + split_ratio=0.5, + z_window_size=1, + architecture="2D", + num_workers=8, + batch_size=128, + normalizations=[ + NormalizeSampled( + keys=["Sensor", "Phase"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=8, + spatial_size=[-1, 128, 128], + keys=["Sensor", "Phase", "Inf_mask"], + w_key="Inf_mask", + ) + ], + ) + dm.prepare_data() + dm.setup(stage="fit") + + model = SemanticSegUNet2D( + in_channels=2, + out_channels=3, + loss_function=torch.nn.CrossEntropyLoss(weight=torch.tensor([0.05, 0.25, 0.7])), + ) + log_dir = Path(os.getenv("MYDATA", "")) / "torch_demo" + trainer = Trainer( + accelerator="gpu", + strategy="ddp_find_unused_parameters_true", + precision=32, + num_nodes=1, + devices=2, + fast_dev_run=True, + max_epochs=100, + logger=TensorBoardLogger(save_dir=log_dir, version="interactive_demo"), + log_every_n_steps=10, + callbacks=[ + LearningRateMonitor(logging_interval="step"), + ModelCheckpoint( + monitor="loss/validate", save_top_k=5, every_n_epochs=1, save_last=True + ), + ], + ) + + torch.set_float32_matmul_precision("high") + trainer.fit(model, dm) + + +if __name__ == "__main__": + main() diff --git a/examples/compbio/ddp.sh b/examples/compbio/ddp.sh new file mode 100644 index 000000000..ba4737bb1 --- /dev/null +++ b/examples/compbio/ddp.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +#SBATCH --job-name=ddp_train +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=2 +#SBATCH --gres=gpu:2 +#SBATCH --partition=gpu +#SBATCH --cpus-per-task=16 +#SBATCH --mem-per-cpu=7G +#SBATCH --time=0-12:00:00 + + +# debugging flags (optional) +# https://lightning.ai/docs/pytorch/stable/clouds/cluster_advanced.html +export NCCL_DEBUG=INFO +export PYTHONFAULTHANDLER=1 + + +module load anaconda/2022.05 +conda activate viscy + +srun python ddp.py \ No newline at end of file diff --git a/examples/compbio/torch_into.py b/examples/compbio/torch_into.py new file mode 100644 index 000000000..5e8434fd1 --- /dev/null +++ b/examples/compbio/torch_into.py @@ -0,0 +1,148 @@ +# %% [markdown] +""" +# Infected cell segmentation + +Interactive script to demonstrate PyTorch Lightning training +with a semantic segmentation task. 
+""" + +# %% +import matplotlib.pyplot as plt +import torch +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint +from lightning.pytorch.loggers import TensorBoardLogger +from skimage.color import label2rgb +from torchview import draw_graph + +from viscy.data.hcs import HCSDataModule +from viscy.scripts.infection_phenotyping.classify_infection_2D import SemanticSegUNet2D +from viscy.transforms import NormalizeSampled, RandWeightedCropd + +# use tf32 for matmul +torch.set_float32_matmul_precision("high") + +# %% [markdown] +""" +## Dataset + +In this dataset, we have images of A549 cells infected with Dengue virus +in two channels: + +- The cells are engineered to express a fluorescent protein (viral sensor) +that translocates from the cytoplasm to the nucleus upon infection. +- Quantitative phase images are reconstructed from brightfield images. + +## Task +The goal is to identify infected and uninfected cells from these images. +For the training target, cell nuclei were segmented from virtual staining, +and manually labelled as infected (1) or uninfected (2), +while background was labelled as 0. +We will train a U-Net to predict these labels from the images. +Is is a semantic segmentation task, +where assign a label (class) to each pixel in the image. +""" + + +# %% +# setup datamodule +data_module = HCSDataModule( + data_path="/hpc/mydata/ziwen.liu/demo/Exp_2024_02_13_DENV_3infMarked_trainVal.zarr", + source_channel=["Sensor", "Phase"], + target_channel=["Inf_mask"], + yx_patch_size=(128, 128), + split_ratio=0.5, + z_window_size=1, + architecture="2D", + num_workers=8, + batch_size=128, + normalizations=[ + NormalizeSampled( + keys=["Sensor", "Phase"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=8, + spatial_size=[-1, 128, 128], + keys=["Sensor", "Phase", "Inf_mask"], + w_key="Inf_mask", + ) + ], +) + +data_module.prepare_data() +data_module.setup(stage="fit") + +# %% +# sample from training data +num_samples = 8 + +for batch in data_module.train_dataloader(): + image = batch["source"][:num_samples].numpy() + label = batch["target"][:num_samples].numpy().astype("uint8") + break + +# %% +# visualize the samples +fig, ax = plt.subplots(num_samples, 3, figsize=(3, 8)) + +for i in range(num_samples): + ax[i, 0].imshow(image[i, 0, 0], cmap="gray") + ax[i, 1].imshow(image[i, 1, 0], cmap="gray") + ax[i, 2].imshow(label2rgb(label[i, 0, 0], bg_label=0)) + +for a in ax.ravel(): + a.axis("off") + +fig.tight_layout() + +# %% +model = SemanticSegUNet2D( + in_channels=2, + out_channels=3, + loss_function=torch.nn.CrossEntropyLoss(weight=torch.tensor([0.05, 0.25, 0.7])), +) + +# %% +model_graph = draw_graph( + model=model, + input_data=torch.rand(1, 2, 1, 128, 128), + graph_name="2D UNet", + roll=True, + depth=2, + device="cpu", +) + +model_graph.visual_graph + +# %% +trainer = Trainer( + accelerator="gpu", + precision=32, + devices=1, + num_nodes=1, + fast_dev_run=True, + max_epochs=100, + logger=TensorBoardLogger( + save_dir="/hpc/mydata/ziwen.liu/demo/logs", + version="interactive_demo", + log_graph=True, + ), + log_every_n_steps=10, + callbacks=[ + LearningRateMonitor(logging_interval="step"), + ModelCheckpoint( + monitor="loss/validate", save_top_k=5, every_n_epochs=1, save_last=True + ), + ], +) + + +# %% +trainer.fit(model, data_module) + +# %% diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_25DModel.py 
b/viscy/scripts/infection_phenotyping/Infection_classification_25DModel.py
new file mode 100644
index 000000000..91702497c
--- /dev/null
+++ b/viscy/scripts/infection_phenotyping/Infection_classification_25DModel.py
@@ -0,0 +1,154 @@
+# %%
+import torch
+import lightning.pytorch as pl
+import torch.nn as nn
+
+from lightning.pytorch.loggers import TensorBoardLogger
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+from viscy.transforms import RandWeightedCropd
+from viscy.transforms import NormalizeSampled
+from viscy.data.hcs import HCSDataModule
+from viscy.scripts.infection_phenotyping.classify_infection_25D import SemanticSegUNet25D
+
+from iohub.ngff import open_ome_zarr
+
+# %% Create a dataloader and visualize the batches.
+
+# Set the path to the dataset
+dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr"
+
+# find ratio of background, uninfected and infected pixels
+zarr_input = open_ome_zarr(
+    dataset_path,
+    layout="hcs",
+    mode="r+",
+)
+in_chan_names = zarr_input.channel_names
+
+num_pixels_bkg = 0
+num_pixels_uninf = 0
+num_pixels_inf = 0
+num_pixels = 0
+for well_id, well_data in zarr_input.wells():
+    well_name, well_no = well_id.split("/")
+
+    for pos_name, pos_data in well_data.positions():
+        data = pos_data.data
+        T, C, Z, Y, X = data.shape
+        out_data = data.numpy()
+        for time in range(T):
+            Inf_mask = out_data[time, in_chan_names.index("Inf_mask"), ...]
+            # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask'
+            num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum()
+            num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum()
+            num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum()
+            num_pixels = num_pixels + Z * X * Y
+
+pixel_ratio_1 = [num_pixels / num_pixels_bkg, num_pixels / num_pixels_uninf, num_pixels / num_pixels_inf]
+pixel_ratio_sum = sum(pixel_ratio_1)
+pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1]
+
+# %% create data module
+
+# Create an instance of HCSDataModule
+data_module = HCSDataModule(
+    dataset_path,
+    source_channel=["Phase", "HSP90"],
+    target_channel=["Inf_mask"],
+    yx_patch_size=[512, 512],
+    split_ratio=0.8,
+    z_window_size=5,
+    architecture="2.5D",
+    num_workers=3,
+    batch_size=32,
+    normalizations=[
+        NormalizeSampled(
+            keys=["Phase", "HSP90"],
+            level="fov_statistics",
+            subtrahend="median",
+            divisor="iqr",
+        )
+    ],
+    augmentations=[
+        RandWeightedCropd(
+            num_samples=4,
+            spatial_size=[-1, 512, 512],
+            keys=["Phase", "HSP90"],
+            w_key="Inf_mask",
+        )
+    ],
+)
+
+# Prepare the data
+data_module.prepare_data()
+
+# Setup the data
+data_module.setup(stage="fit")
+
+# Create a dataloader
+train_dm = data_module.train_dataloader()
+
+val_dm = data_module.val_dataloader()
+
+# Visualize the dataset and the batch using napari
+# Set the display
+# os.environ['DISPLAY'] = ':1'
+
+# # Create a napari viewer
+# viewer = napari.Viewer()
+
+# # Add the dataset to the viewer
+# for batch in dataloader:
+#     if isinstance(batch, dict):
+#         for k, v in batch.items():
+#             if isinstance(v, torch.Tensor):
+#                 viewer.add_image(v.cpu().numpy().astype(np.float32))
+
+# # Start the napari event loop
+# napari.run()
+
+
+# %% Define the logger
+logger = TensorBoardLogger(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/",
+    name="logs",
+)
+
+# Pass the logger to the Trainer
+trainer = pl.Trainer(
+    logger=logger,
+    max_epochs=200,
default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + log_every_n_steps=1, + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + +# Define the checkpoint callback +checkpoint_callback = ModelCheckpoint( + dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + filename="checkpoint_{epoch:02d}", + save_top_k=-1, + verbose=True, + monitor="loss/validate", + mode="min", +) + +# Add the checkpoint callback to the trainer +trainer.callbacks.append(checkpoint_callback) + +# Fit the model +model = SemanticSegUNet25D( + in_channels=2, + out_channels=3, + loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), +) + +print(model) + +# %% +# Run training. + +trainer.fit(model, data_module) + +# %% diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py b/viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py new file mode 100644 index 000000000..52af46732 --- /dev/null +++ b/viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py @@ -0,0 +1,118 @@ +# %% +import torch +import lightning.pytorch as pl +import torch.nn as nn + +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.callbacks import ModelCheckpoint + +from viscy.transforms import RandWeightedCropd +from viscy.transforms import NormalizeSampled +from viscy.data.hcs import HCSDataModule +from viscy.scripts.infection_phenotyping.classify_infection_2D import SemanticSegUNet2D + +# %% Create a dataloader and visualize the batches. + +# Set the path to the dataset +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_trainVal.zarr" + +# Create an instance of HCSDataModule +data_module = HCSDataModule( + dataset_path, + source_channel=["Sensor", "Phase"], + target_channel=["Inf_mask"], + yx_patch_size=[128, 128], + split_ratio=0.8, + z_window_size=1, + architecture="2D", + num_workers=1, + batch_size=128, + normalizations=[ + NormalizeSampled( + keys=["Sensor", "Phase"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=8, + spatial_size=[-1, 128, 128], + keys=["Sensor", "Phase", "Inf_mask"], + w_key="Inf_mask", + ) + ], +) + +# Prepare the data +data_module.prepare_data() + +# Setup the data +data_module.setup(stage="fit") + +# Create a dataloader +train_dm = data_module.train_dataloader() + +val_dm = data_module.val_dataloader() + +# Visualize the dataset and the batch using napari +# Set the display +# os.environ['DISPLAY'] = ':1' + +# # Create a napari viewer +# viewer = napari.Viewer() + +# # Add the dataset to the viewer +# for batch in dataloader: +# if isinstance(batch, dict): +# for k, v in batch.items(): +# if isinstance(v, torch.Tensor): +# viewer.add_image(v.cpu().numpy().astype(np.float32)) + +# # Start the napari event loop +# napari.run() + + +# %% Define the logger +logger = TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/", + name="logs_wPhase", +) + +# Pass the logger to the Trainer +trainer = pl.Trainer( + logger=logger, + max_epochs=100, + 
default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", + log_every_n_steps=1, + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + +# Define the checkpoint callback +checkpoint_callback = ModelCheckpoint( + dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/", + filename="checkpoint_{epoch:02d}", + save_top_k=-1, + verbose=True, + monitor="loss/validate", + mode="min", +) + +# Add the checkpoint callback to the trainer +trainer.callbacks.append(checkpoint_callback) + +# Fit the model +model = SemanticSegUNet2D( + in_channels=2, + out_channels=3, + loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.05, 0.25, 0.7])), +) + +print(model) +# %% +# Run training. + +trainer.fit(model, data_module) + +# %% diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py new file mode 100644 index 000000000..0ecd6bdd4 --- /dev/null +++ b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py @@ -0,0 +1,156 @@ +# %% +# import sys +# sys.path.append("/hpc/mydata/soorya.pradeep/viscy_infection_phenotyping/Viscy/") +import torch +import lightning.pytorch as pl +import torch.nn as nn + +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.callbacks import ModelCheckpoint + +from viscy.transforms import RandWeightedCropd +from viscy.transforms import NormalizeSampled +from viscy.data.hcs import HCSDataModule +from viscy.scripts.infection_phenotyping.classify_infection_covnext import SemanticSegUNet25D + +from iohub.ngff import open_ome_zarr + +# %% Create a dataloader and visualize the batches. + +# Set the path to the dataset +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated_train.zarr" + +# find ratio of background, uninfected and infected pixels +zarr_input = open_ome_zarr( + dataset_path, + layout="hcs", + mode="r+", +) +in_chan_names = zarr_input.channel_names + +num_pixels_bkg = 0 +num_pixels_uninf = 0 +num_pixels_inf = 0 +num_pixels = 0 +for well_id, well_data in zarr_input.wells(): + well_name, well_no = well_id.split("/") + + for pos_name, pos_data in well_data.positions(): + data = pos_data.data + T,C,Z,Y,X = data.shape + out_data = data.numpy() + for time in range(T): + Inf_mask = out_data[time,in_chan_names.index("Inf_mask"),...] 
+            # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask'
+            num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum()
+            num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum()
+            num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum()
+            num_pixels = num_pixels + Z * X * Y
+
+pixel_ratio_1 = [num_pixels / num_pixels_bkg, num_pixels / num_pixels_uninf, num_pixels / num_pixels_inf]
+pixel_ratio_sum = sum(pixel_ratio_1)
+pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1]
+
+# %% create data module
+
+# Create an instance of HCSDataModule
+data_module = HCSDataModule(
+    dataset_path,
+    source_channel=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
+    target_channel=["Inf_mask"],
+    yx_patch_size=[256, 256],
+    split_ratio=0.8,
+    z_window_size=5,
+    architecture="2.2D",
+    num_workers=3,
+    batch_size=16,
+    normalizations=[
+        NormalizeSampled(
+            keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
+            level="fov_statistics",
+            subtrahend="median",
+            divisor="iqr",
+        )
+    ],
+    augmentations=[
+        RandWeightedCropd(
+            num_samples=4,
+            spatial_size=[-1, 256, 256],
+            keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"],
+            w_key="Inf_mask",
+        )
+    ],
+)
+
+# Prepare the data
+data_module.prepare_data()
+
+# Setup the data
+data_module.setup(stage="fit")
+
+# Create a dataloader
+train_dm = data_module.train_dataloader()
+
+val_dm = data_module.val_dataloader()
+
+# Visualize the dataset and the batch using napari
+# Set the display
+# os.environ['DISPLAY'] = ':1'
+
+# # Create a napari viewer
+# viewer = napari.Viewer()
+
+# # Add the dataset to the viewer
+# for batch in dataloader:
+#     if isinstance(batch, dict):
+#         for k, v in batch.items():
+#             if isinstance(v, torch.Tensor):
+#                 viewer.add_image(v.cpu().numpy().astype(np.float32))
+
+# # Start the napari event loop
+# napari.run()
+
+
+# %% Define the logger
+logger = TensorBoardLogger(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/",
+    name="logs",
+)
+
+# Pass the logger to the Trainer
+trainer = pl.Trainer(
+    logger=logger,
+    max_epochs=200,
+    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
+    log_every_n_steps=1,
+    devices=1,  # Set the number of GPUs to use. This avoids a run-time exception from distributed training when the node has multiple GPUs.
+)
+
+# Define the checkpoint callback
+checkpoint_callback = ModelCheckpoint(
+    dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
+    filename="checkpoint_{epoch:02d}",
+    save_top_k=-1,
+    verbose=True,
+    monitor="loss/validate",
+    mode="min",
+)
+
+# Add the checkpoint callback to the trainer
+trainer.callbacks.append(checkpoint_callback)
+
+# Fit the model
+model = SemanticSegUNet25D(
+    in_channels=4,
+    out_channels=3,
+    # pixel_ratio holds numpy float64 values; cast to float32 for the loss weights
+    loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio, dtype=torch.float32)),
+)
+
+print(model)
+
+# %%
+# Run training.
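+# Sanity-check the computed class weights before launching the fit; for a
+# background-dominated dataset the background weight should be the smallest.
+print(f"Class weights (background, uninfected, infected): {pixel_ratio}")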
+
+trainer.fit(model, data_module)
+
+# %%
diff --git a/viscy/scripts/infection_phenotyping/classify_infection_25D.py b/viscy/scripts/infection_phenotyping/classify_infection_25D.py
new file mode 100644
index 000000000..c78a7e8f0
--- /dev/null
+++ b/viscy/scripts/infection_phenotyping/classify_infection_25D.py
@@ -0,0 +1,335 @@
+
+import torch
+import torch.nn as nn
+import lightning.pytorch as pl
+import torch.nn.functional as F
+from torch import Tensor
+import cv2
+
+# import torchview
+from typing import Literal, Sequence
+from skimage.exposure import rescale_intensity
+from matplotlib.pyplot import get_cmap
+from skimage.measure import regionprops, label
+import numpy as np
+import matplotlib.pyplot as plt
+
+from monai.transforms import DivisiblePad
+from viscy.unet.networks.Unet25D import Unet25d
+from viscy.data.hcs import Sample
+
+#
+# %% Methods to compute a confusion matrix per cell
+
+# The confusion matrix at the single-cell resolution.
+def confusion_matrix_per_cell(
+    y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
+):
+    """Compute confusion matrix per cell.
+
+    Args:
+        y_true (torch.Tensor): Ground truth label image (BXHXW).
+        y_pred (torch.Tensor): Predicted label image (BXHXW).
+        num_classes (int): Number of classes.
+
+    Returns:
+        torch.Tensor: Confusion matrix (CXC), accumulated over the batch.
+    """
+    # Convert the image class to the nuclei class
+    confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes)
+    confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell)
+    return confusion_matrix_per_cell
+
+
+# These images can be logged with prediction.
+def compute_confusion_matrix(
+    y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
+):
+    """Accumulate a nucleus-level confusion matrix from label images.
+
+    Args:
+        y_true (torch.Tensor): Ground truth label image (BXHXW). Values are
+            integers that represent the semantic segmentation.
+        y_pred (torch.Tensor): Predicted label image (BXHXW).
+        num_classes (int): Number of classes.
+
+    Returns:
+        np.ndarray: Confusion matrix (num_classes x num_classes) accumulated
+            at the centroids of the ground-truth nuclei.
+    """
+
+    batch_size = y_true.size(0)
+    # find centroids of nuclei from y_true
+    conf_mat = np.zeros((num_classes, num_classes))
+    for i in range(batch_size):
+        y_true_cpu = y_true[i].cpu().numpy()
+        y_pred_cpu = y_pred[i].cpu().numpy()
+        y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:])
+        y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:])
+        y_pred_resized = cv2.resize(
+            y_pred_reshaped,
+            dsize=y_true_reshaped.shape[::-1],
+            interpolation=cv2.INTER_NEAREST,
+        )
+        y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0)
+
+        # find objects in every image
+        label_img = label(y_true_reshaped)
+        regions = regionprops(label_img)
+
+        # Find centroids, pixel coordinates from the ground truth.
+        for region in regions:
+            if region.area > 0:
+                row, col = region.centroid
+                pred_id = y_pred_resized[int(row), int(col)]
+                test_id = y_true_reshaped[int(row), int(col)]
+
+                if pred_id == 1 and test_id == 1:
+                    conf_mat[1, 1] += 1
+                if pred_id == 1 and test_id == 2:
+                    conf_mat[0, 1] += 1
+                if pred_id == 2 and test_id == 1:
+                    conf_mat[1, 0] += 1
+                if pred_id == 2 and test_id == 2:
+                    conf_mat[0, 0] += 1
+    # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction.
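+    # Index convention of the accumulated matrix: ground-truth class along rows
+    # and predicted class along columns, with label 2 mapped to index 0 and
+    # label 1 to index 1 (matching index_to_label_dict used for plotting).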
+ return conf_mat + +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + index_to_label_dict = dict(enumerate(index_to_label_dict)) # Convert list to dictionary + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + # plt.show(fig) # Show the figure + return fig +# Define a 25d unet model for infection classification as a lightning module. + +class SemanticSegUNet25D(pl.LightningModule): + # Model for semantic segmentation. + def __init__( + self, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-3, # Learning rate + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch + ckpt_path: str = None, # Path to the checkpoint + ): + super(SemanticSegUNet25D, self).__init__() # Call the superclass initializer + # Initialize the UNet model + self.unet_model = Unet25d(in_channels=in_channels, out_channels=out_channels, num_blocks=4, num_block_layers=4) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Infected", "Uninfected"] + + + # Define the forward pass + def forward(self, x): + return self.unet_model(x) # Pass the input through the UNet model + + # Define the optimizer + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer + return optimizer + + # Define the training step + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using 
the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the training loss + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss # Return the training loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the validation loss + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) + return loss # Return the validation loss + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + # Define the prediction step + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. 
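+        # The class dimension of prob_pred is dim=1; argmax with keepdim=True
+        # collapses the per-class probabilities into one integer-label channel.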
+ labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + # prob_chan = prob_pred[:, 2, :, :] + # prob_chan = prob_chan.unsqueeze(1) + return labels_pred # log the class predicted image + # return prob_chan # log the probability predicted image + + def on_test_start(self): + self.pred_cm = torch.zeros((2,2)) + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + def test_step(self, batch: Sample): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + + target = self._predict_pad(batch["target"]) # Extract the target from the batch + pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=2 + ) # Calculate the confusion matrix per cell + self.pred_cm += pred_cm # Append the confusion matrix to pred_cm + + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(pred_cm, self.index_to_label_dict), + self.current_epoch, + ) + + # Accumulate the confusion matrix at the end of test epoch and log. + def on_test_end(self): + confusion_matrix_sum = self.pred_cm + self.logger.experiment.add_figure( + "Confusion Matrix", + plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), + self.current_epoch, + ) + + # Define what happens at the end of a training epoch + def on_train_epoch_end(self): + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + + # Define what happens at the end of a validation epoch + def on_validation_epoch_end(self): + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + + # Define a method to detach a sample + def _detach_sample(self, imgs: Sequence[Tensor]): + # Detach the images and convert them to numpy arrays + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + # Define a method to log samples + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate( + sample_images + ): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) +# %% diff --git 
a/viscy/scripts/infection_phenotyping/classify_infection_2D.py b/viscy/scripts/infection_phenotyping/classify_infection_2D.py
new file mode 100644
index 000000000..74a6038e9
--- /dev/null
+++ b/viscy/scripts/infection_phenotyping/classify_infection_2D.py
@@ -0,0 +1,353 @@
+# import torchview
+from typing import Literal, Sequence
+
+import cv2
+import lightning.pytorch as pl
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from matplotlib.pyplot import get_cmap
+from monai.transforms import DivisiblePad
+from skimage.exposure import rescale_intensity
+from skimage.measure import label, regionprops
+from torch import Tensor
+
+# from viscy.unet.networks.Unet25D import Unet25d
+from viscy.data.hcs import Sample
+from viscy.unet.networks.Unet2D import Unet2d
+
+#
+# %% Methods to compute a confusion matrix per cell
+
+
+# The confusion matrix at the single-cell resolution.
+def confusion_matrix_per_cell(
+    y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
+):
+    """Compute confusion matrix per cell.
+
+    Args:
+        y_true (torch.Tensor): Ground truth label image (BXHXW).
+        y_pred (torch.Tensor): Predicted label image (BXHXW).
+        num_classes (int): Number of classes.
+
+    Returns:
+        torch.Tensor: Confusion matrix (CXC), accumulated over the batch.
+    """
+    # Convert the image class to the nuclei class
+    confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes)
+    confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell)
+    return confusion_matrix_per_cell
+
+
+# These images can be logged with prediction.
+def compute_confusion_matrix(
+    y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
+):
+    """Accumulate a nucleus-level confusion matrix from label images.
+
+    Args:
+        y_true (torch.Tensor): Ground truth label image (BXHXW). Values are
+            integers that represent the semantic segmentation.
+        y_pred (torch.Tensor): Predicted label image (BXHXW).
+        num_classes (int): Number of classes.
+
+    Returns:
+        np.ndarray: Confusion matrix (num_classes x num_classes) accumulated
+            at the centroids of the ground-truth nuclei.
+    """
+
+    batch_size = y_true.size(0)
+    # find centroids of nuclei from y_true
+    conf_mat = np.zeros((num_classes, num_classes))
+    for i in range(batch_size):
+        y_true_cpu = y_true[i].cpu().numpy()
+        y_pred_cpu = y_pred[i].cpu().numpy()
+        y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:])
+        y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:])
+        y_pred_resized = cv2.resize(
+            y_pred_reshaped,
+            dsize=y_true_reshaped.shape[::-1],
+            interpolation=cv2.INTER_NEAREST,
+        )
+        y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0)
+
+        # find objects in every image
+        label_img = label(y_true_reshaped)
+        regions = regionprops(label_img)
+
+        # Find centroids, pixel coordinates from the ground truth.
+        for region in regions:
+            if region.area > 0:
+                row, col = region.centroid
+                pred_id = y_pred_resized[int(row), int(col)]
+                test_id = y_true_reshaped[int(row), int(col)]
+
+                if pred_id == 1 and test_id == 1:
+                    conf_mat[1, 1] += 1
+                if pred_id == 1 and test_id == 2:
+                    conf_mat[0, 1] += 1
+                if pred_id == 2 and test_id == 1:
+                    conf_mat[1, 0] += 1
+                if pred_id == 2 and test_id == 2:
+                    conf_mat[0, 0] += 1
+    # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction.
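+    # Index convention of the accumulated matrix: ground-truth class along rows
+    # and predicted class along columns, with label 2 mapped to index 0 and
+    # label 1 to index 1 (matching index_to_label_dict used for plotting).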
+ return conf_mat + + +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + index_to_label_dict = dict( + enumerate(index_to_label_dict) + ) # Convert list to dictionary + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + # plt.show(fig) # Show the figure + return fig + + +# Define a 2d unet model for infection classification as a lightning module. + + +class SemanticSegUNet2D(pl.LightningModule): + # Model for semantic segmentation. + def __init__( + self, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-3, # Learning rate + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch + ckpt_path: str = None, # Path to the checkpoint + ): + super(SemanticSegUNet2D, self).__init__() # Call the superclass initializer + # Initialize the UNet model + self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Infected", "Uninfected"] + + # Define the forward pass + def forward(self, x): + return self.unet_model(x) # Pass the input through the UNet model + + # Define the optimizer + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer + return optimizer + + # Define the training step + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the 
target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the training loss + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss # Return the training loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the validation loss + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) + return loss # Return the validation loss + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + # Define the prediction step + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. 
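+        # The class dimension of prob_pred is dim=1; argmax with keepdim=True
+        # collapses the per-class probabilities into one integer-label channel.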
+ labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels + # prob_chan = prob_pred[:, 2, :, :] + # prob_chan = prob_chan.unsqueeze(1) + return labels_pred # log the class predicted image + # return prob_chan # log the probability predicted image + + def on_test_start(self): + self.pred_cm = torch.zeros((2, 2)) + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + def test_step(self, batch: Sample): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels + + target = self._predict_pad(batch["target"]) # Extract the target from the batch + pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=2 + ) # Calculate the confusion matrix per cell + self.pred_cm += pred_cm # Append the confusion matrix to pred_cm + + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(pred_cm, self.index_to_label_dict), + self.current_epoch, + ) + + # Accumulate the confusion matrix at the end of test epoch and log. + def on_test_end(self): + confusion_matrix_sum = self.pred_cm + self.logger.experiment.add_figure( + "Confusion Matrix", + plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), + self.current_epoch, + ) + + # Define what happens at the end of a training epoch + def on_train_epoch_end(self): + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + + # Define what happens at the end of a validation epoch + def on_validation_epoch_end(self): + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + + # Define a method to detach a sample + def _detach_sample(self, imgs: Sequence[Tensor]): + # Detach the images and convert them to numpy arrays + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + # Define a method to log samples + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate( + sample_images + ): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + + +# %% diff --git 
a/viscy/scripts/infection_phenotyping/classify_infection_covnext.py b/viscy/scripts/infection_phenotyping/classify_infection_covnext.py
new file mode 100644
index 000000000..2ba698eed
--- /dev/null
+++ b/viscy/scripts/infection_phenotyping/classify_infection_covnext.py
@@ -0,0 +1,347 @@
+
+import torch
+import torch.nn as nn
+import lightning.pytorch as pl
+import torch.nn.functional as F
+from torch import Tensor
+import cv2
+
+# import torchview
+from typing import Literal, Sequence
+from skimage.exposure import rescale_intensity
+from matplotlib.pyplot import get_cmap
+from skimage.measure import regionprops, label
+import numpy as np
+import matplotlib.pyplot as plt
+
+from monai.transforms import DivisiblePad
+from viscy.unet.networks.Unet25D import Unet25d
+from viscy.data.hcs import Sample
+from viscy.light.engine import VSUNet
+
+#
+# %% Methods to compute a confusion matrix per cell
+
+# The confusion matrix at the single-cell resolution.
+def confusion_matrix_per_cell(
+    y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
+):
+    """Compute confusion matrix per cell.
+
+    Args:
+        y_true (torch.Tensor): Ground truth label image (BXHXW).
+        y_pred (torch.Tensor): Predicted label image (BXHXW).
+        num_classes (int): Number of classes.
+
+    Returns:
+        torch.Tensor: Confusion matrix (CXC), accumulated over the batch.
+    """
+    # Convert the image class to the nuclei class
+    confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes)
+    confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell)
+    return confusion_matrix_per_cell
+
+
+# These images can be logged with prediction.
+def compute_confusion_matrix(
+    y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
+):
+    """Accumulate a nucleus-level confusion matrix from label images.
+
+    Args:
+        y_true (torch.Tensor): Ground truth label image (BXHXW). Values are
+            integers that represent the semantic segmentation.
+        y_pred (torch.Tensor): Predicted label image (BXHXW).
+        num_classes (int): Number of classes.
+
+    Returns:
+        np.ndarray: Confusion matrix (num_classes x num_classes) accumulated
+            at the centroids of the ground-truth nuclei.
+    """
+
+    batch_size = y_true.size(0)
+    # find centroids of nuclei from y_true
+    conf_mat = np.zeros((num_classes, num_classes))
+    for i in range(batch_size):
+        y_true_cpu = y_true[i].cpu().numpy()
+        y_pred_cpu = y_pred[i].cpu().numpy()
+        y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:])
+        y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:])
+        y_pred_resized = cv2.resize(
+            y_pred_reshaped,
+            dsize=y_true_reshaped.shape[::-1],
+            interpolation=cv2.INTER_NEAREST,
+        )
+        y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0)
+
+        # find objects in every image
+        label_img = label(y_true_reshaped)
+        regions = regionprops(label_img)
+
+        # Find centroids, pixel coordinates from the ground truth.
+        for region in regions:
+            if region.area > 0:
+                row, col = region.centroid
+                pred_id = y_pred_resized[int(row), int(col)]
+                test_id = y_true_reshaped[int(row), int(col)]
+
+                if pred_id == 1 and test_id == 1:
+                    conf_mat[1, 1] += 1
+                if pred_id == 1 and test_id == 2:
+                    conf_mat[0, 1] += 1
+                if pred_id == 2 and test_id == 1:
+                    conf_mat[1, 0] += 1
+                if pred_id == 2 and test_id == 2:
+                    conf_mat[0, 0] += 1
+    # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction.
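+    # Index convention of the accumulated matrix: ground-truth class along rows
+    # and predicted class along columns, with label 2 mapped to index 0 and
+    # label 1 to index 1 (matching index_to_label_dict used for plotting).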
+ return conf_mat + +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + index_to_label_dict = dict(enumerate(index_to_label_dict)) # Convert list to dictionary + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + # plt.show(fig) # Show the figure + return fig +# Define a 25d unet model for infection classification as a lightning module. + +class SemanticSegUNet25D(pl.LightningModule): + # Model for semantic segmentation. + def __init__( + self, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-3, # Learning rate + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch + ckpt_path: str = None, # Path to the checkpoint + ): + super(SemanticSegUNet25D, self).__init__() # Call the superclass initializer + # Initialize the UNet model + self.unet_model = VSUNet( + architecture="2.2D", + model_config={ + "in_channels": in_channels, + "out_channels": out_channels, + "in_stack_depth": 5, + "backbone": "convnextv2_tiny", + "stem_kernel_size": (5, 4, 4), + "decoder_mode": "pixelshuffle", + "head_expansion_ratio": 4, + }, + ) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Infected", "Uninfected"] + + + # Define the forward pass + def forward(self, x): + return self.unet_model(x) # Pass the input through the UNet model + + # Define the optimizer + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer + return optimizer + + # Define the training step + def training_step(self, batch: Sample, batch_idx: int): + 
source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the training loss + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss # Return the training loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the validation loss + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) + return loss # Return the validation loss + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + # Define the prediction step + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. 
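+        # The class dimension of prob_pred is dim=1; argmax with keepdim=True
+        # collapses the per-class probabilities into one integer-label channel.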
+ labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + # prob_chan = prob_pred[:, 2, :, :] + # prob_chan = prob_chan.unsqueeze(1) + return labels_pred # log the class predicted image + # return prob_chan # log the probability predicted image + + def on_test_start(self): + self.pred_cm = torch.zeros((2,2)) + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + def test_step(self, batch: Sample): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + + target = self._predict_pad(batch["target"]) # Extract the target from the batch + pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=2 + ) # Calculate the confusion matrix per cell + self.pred_cm += pred_cm # Append the confusion matrix to pred_cm + + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(pred_cm, self.index_to_label_dict), + self.current_epoch, + ) + + # Accumulate the confusion matrix at the end of test epoch and log. + def on_test_end(self): + confusion_matrix_sum = self.pred_cm + self.logger.experiment.add_figure( + "Confusion Matrix", + plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), + self.current_epoch, + ) + + # Define what happens at the end of a training epoch + def on_train_epoch_end(self): + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + + # Define what happens at the end of a validation epoch + def on_validation_epoch_end(self): + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + + # Define a method to detach a sample + def _detach_sample(self, imgs: Sequence[Tensor]): + # Detach the images and convert them to numpy arrays + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + # Define a method to log samples + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate( + sample_images + ): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) +# %% diff --git 
diff --git a/viscy/scripts/infection_phenotyping/predict_infection_classifier.py b/viscy/scripts/infection_phenotyping/predict_infection_classifier.py
new file mode 100644
index 000000000..783c13340
--- /dev/null
+++ b/viscy/scripts/infection_phenotyping/predict_infection_classifier.py
@@ -0,0 +1,56 @@
+# %%
+
+from viscy.light.predict_writer import HCSPredictionWriter
+from viscy.data.hcs import HCSDataModule
+import lightning.pytorch as pl
+from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D
+from viscy.transforms import NormalizeSampled
+
+# %% write the predictions to a zarr file
+
+pred_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr"
+
+data_module = HCSDataModule(
+    data_path=pred_datapath,
+    source_channel=["Sensor", "Phase"],
+    target_channel=["Inf_mask"],
+    split_ratio=0.8,
+    z_window_size=1,
+    architecture="2D",
+    num_workers=1,
+    batch_size=1,
+    normalizations=[
+        NormalizeSampled(
+            keys=["Phase", "Sensor"],
+            level="fov_statistics",
+            subtrahend="median",
+            divisor="iqr",
+        )
+    ],
+)
+data_module.prepare_data()
+data_module.setup(stage="predict")
+
+model = SemanticSegUNet2D(
+    in_channels=2,
+    out_channels=3,
+    ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt",
+)
+
+# %% perform prediction
+
+output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred_SP.zarr"
+
+trainer = pl.Trainer(
+    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase",
+    callbacks=[HCSPredictionWriter(output_path, write_input=False)],
+    devices=1,  # Use a single device to avoid a runtime exception from distributed training on multi-GPU nodes
+)
+
+trainer.predict(
+    model=model,
+    datamodule=data_module,
+    return_predictions=True,
+)
+
+# %%
diff --git a/viscy/scripts/infection_phenotyping/readme.md b/viscy/scripts/infection_phenotyping/readme.md
new file mode 100644
index 000000000..74dbc5000
--- /dev/null
+++ b/viscy/scripts/infection_phenotyping/readme.md
@@ -0,0 +1,7 @@
+# Infection Classification Model
+
+This directory contains the code for the infection classification model (`infection_classification_model.py`) used in the infection phenotyping project.
+
+## Overview
+
+The `infection_classification_model.py` file implements a machine learning model for classifying infection state from image features. The model is trained on a labeled dataset of either fluorescence or label-free images and can be used to predict the infection type for new samples.
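+
+## Usage
+
+A typical prediction run loads a trained checkpoint and writes class labels to a zarr store via `HCSPredictionWriter` (see `predict_infection_classifier.py` in this directory for the full script; the paths below are placeholders):
+
+```python
+import lightning.pytorch as pl
+
+from viscy.data.hcs import HCSDataModule
+from viscy.light.predict_writer import HCSPredictionWriter
+from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D
+from viscy.transforms import NormalizeSampled
+
+data_module = HCSDataModule(
+    data_path="input.zarr",  # placeholder
+    source_channel=["Sensor", "Phase"],
+    target_channel=["Inf_mask"],
+    z_window_size=1,
+    architecture="2D",
+    batch_size=1,
+    normalizations=[
+        NormalizeSampled(
+            keys=["Sensor", "Phase"],
+            level="fov_statistics",
+            subtrahend="median",
+            divisor="iqr",
+        )
+    ],
+)
+data_module.prepare_data()
+data_module.setup(stage="predict")
+
+model = SemanticSegUNet2D(
+    in_channels=2, out_channels=3, ckpt_path="checkpoint.ckpt"  # placeholder
+)
+trainer = pl.Trainer(
+    devices=1,
+    callbacks=[HCSPredictionWriter("predictions.zarr", write_input=False)],  # placeholder path
+)
+trainer.predict(model=model, datamodule=data_module, return_predictions=False)
+```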
\ No newline at end of file
diff --git a/viscy/scripts/infection_phenotyping/test_infection_classifier.py b/viscy/scripts/infection_phenotyping/test_infection_classifier.py
new file mode 100644
index 000000000..5ed140946
--- /dev/null
+++ b/viscy/scripts/infection_phenotyping/test_infection_classifier.py
@@ -0,0 +1,137 @@
+# %%
+from viscy.data.hcs import HCSDataModule
+import lightning.pytorch as pl
+from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D
+from lightning.pytorch.loggers import TensorBoardLogger
+from viscy.transforms import NormalizeSampled
+
+# %% test the model on the test set
+test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr"
+
+data_module = HCSDataModule(
+    data_path=test_datapath,
+    source_channel=["Sensor", "Phase"],
+    target_channel=["Inf_mask"],
+    split_ratio=0.8,
+    z_window_size=1,
+    architecture="2D",
+    num_workers=0,
+    batch_size=1,
+    normalizations=[
+        NormalizeSampled(
+            keys=["Sensor", "Phase"],
+            level="fov_statistics",
+            subtrahend="median",
+            divisor="iqr",
+        )
+    ],
+)
+
+data_module.setup(stage="test")
+
+# %% create trainer and input
+
+logger = TensorBoardLogger(
+    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/",
+    name="logs_wPhase",
+)
+
+trainer = pl.Trainer(
+    logger=logger,
+    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase",
+    log_every_n_steps=1,
+    devices=1,  # Use a single device to avoid a runtime exception from distributed training on multi-GPU nodes
+)
+
+model = SemanticSegUNet2D(
+    in_channels=2,
+    out_channels=3,
+    ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt",
+)
+
+trainer.test(model=model, datamodule=data_module)
+
+
+# # %% script to develop confusion matrix for infected cell classifier
+
+# from iohub.ngff import open_ome_zarr
+# import numpy as np
+# from skimage.measure import regionprops, label
+# import cv2
+# import seaborn as sns
+# import matplotlib.pyplot as plt
+
+# # %% load the predicted zarr and the human-in-the-loop annotations zarr
+
+# pred_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred.zarr"
+# test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr"
+
+# pred_dataset = open_ome_zarr(
+#     pred_datapath,
+#     layout="hcs",
+#     mode="r+",
+# )
+# chan_pred = pred_dataset.channel_names
+
+# test_dataset = open_ome_zarr(
+#     test_datapath,
+#     layout="hcs",
+#     mode="r+",
+# )
+# chan_test = test_dataset.channel_names
+
+# # %% compute the confusion matrix for each position
+# for well_id, well_data in pred_dataset.wells():
+#     well_name, well_no = well_id.split("/")
+
+#     for pos_name, pos_data in well_data.positions():
+
+#         pred_data = pos_data.data
+#         pred_pos_data = pred_data.numpy()
+#         T, C, Z, X, Y = pred_pos_data.shape
+
+#         test_data = test_dataset[well_id + "/" + pos_name + "/0"]
+#         test_pos_data = test_data.numpy()
+
+#         # compute the confusion matrix for each time point and add it to the total
+#         conf_mat = np.zeros((2, 2))
+#         for time in range(T):
+#             pred_img = pred_pos_data[time, chan_pred.index("Inf_mask_prediction"), 0, :, :]
+#             test_img = test_pos_data[time, chan_test.index("Inf_mask"), 0, :, :]
+
+#             test_img_rz = cv2.resize(
+#                 test_img, dsize=(2016, 2048), interpolation=cv2.INTER_NEAREST
+#             )
+#             pred_img = np.where(test_img_rz > 0, pred_img, 0)
+
+#             # find objects in every image
+#             label_img = label(test_img_rz)
+#             regions_label = regionprops(label_img)
+
+#             # sample the pixel id at the centroid of every labelled cell in the pred and test images
+#             for region in regions_label:
+#                 if region.area > 0:
+#                     row, col = region.centroid
+#                     pred_id = pred_img[int(row), int(col)]
+#                     test_id = test_img_rz[int(row), int(col)]
+#                     if pred_id == 1 and test_id == 1:
+#                         conf_mat[1, 1] += 1
+#                     if pred_id == 1 and test_id == 2:
+#                         conf_mat[1, 0] += 1
+#                     if pred_id == 2 and test_id == 1:
+#                         conf_mat[0, 1] += 1
+#                     if pred_id == 2 and test_id == 2:
+#                         conf_mat[0, 0] += 1
+
+#         # display the confusion matrix
+#         ax = plt.subplot()
+#         # annot=True to annotate cells, fmt="g" to disable scientific notation
+#         sns.heatmap(conf_mat, annot=True, fmt="g", ax=ax)
+
+#         # labels, title and ticks
+#         ax.set_xlabel("annotated labels")
+#         ax.set_ylabel("predicted labels")
+#         ax.set_title("Confusion Matrix")
+#         ax.xaxis.set_ticklabels(["infected", "uninfected"])
+#         ax.yaxis.set_ticklabels(["infected", "uninfected"])
+
+
+# # %%
+# %%
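The commented-out block above scores predictions at the cell level by sampling one pixel per annotated cell (the region centroid) and tallying a 2x2 confusion matrix. A condensed, self-contained sketch of that logic on synthetic masks (class ids follow the dataset convention: 0 = background, 1 = infected, 2 = uninfected; the matrix indexing mirrors the commented code):

import numpy as np
from skimage.measure import label, regionprops

rng = np.random.default_rng(0)
test_img = np.zeros((64, 64), dtype=np.uint8)  # annotated labels
pred_img = np.zeros((64, 64), dtype=np.uint8)  # predicted labels
# Paint a grid of square "cells" with random infected (1) / uninfected (2) labels
for top in range(4, 60, 16):
    for left in range(4, 60, 16):
        test_img[top : top + 8, left : left + 8] = rng.integers(1, 3)
        pred_img[top : top + 8, left : left + 8] = rng.integers(1, 3)

conf_mat = np.zeros((2, 2))
for region in regionprops(label(test_img)):
    row, col = region.centroid  # sample one pixel per annotated cell
    pred_id = pred_img[int(row), int(col)]
    test_id = test_img[int(row), int(col)]
    if pred_id == 1 and test_id == 1:
        conf_mat[1, 1] += 1
    elif pred_id == 1 and test_id == 2:
        conf_mat[1, 0] += 1
    elif pred_id == 2 and test_id == 1:
        conf_mat[0, 1] += 1
    elif pred_id == 2 and test_id == 2:
        conf_mat[0, 0] += 1
print(conf_mat)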